This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new e67408c859 KAFKA-10199: Implement removing active and standby tasks 
from the state updater (#12270)
e67408c859 is described below

commit e67408c859fb2a80f1b3c208b7fef6ddc9a711fb
Author: Bruno Cadonna <cado...@apache.org>
AuthorDate: Thu Jun 9 19:28:26 2022 +0200

    KAFKA-10199: Implement removing active and standby tasks from the state 
updater (#12270)
    
    This PR extends the default implementation of the state updater so that 
active and standby tasks can be removed from it. The PR also includes 
refactoring that cleans up the code.
    
    Reviewers: Guozhang Wang <wangg...@gmail.com>
---
 .../processor/internals/DefaultStateUpdater.java   | 129 ++++---
 .../streams/processor/internals/StateUpdater.java  |  78 ++--
 .../streams/processor/internals/TaskAndAction.java |  67 ++++
 .../internals/DefaultStateUpdaterTest.java         | 419 +++++++++++++++++----
 .../processor/internals/TaskAndActionTest.java     |  68 ++++
 5 files changed, 595 insertions(+), 166 deletions(-)

diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
index 55935d3e21..54cb7bc427 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
@@ -23,13 +23,13 @@ import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.Task.State;
+import org.apache.kafka.streams.processor.internals.TaskAndAction.Action;
 import org.slf4j.Logger;
 
 import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedList;
 import java.util.List;
@@ -37,6 +37,7 @@ import java.util.Map;
 import java.util.Queue;
 import java.util.Set;
 import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
@@ -57,7 +58,7 @@ public class DefaultStateUpdater implements StateUpdater {
         private final ChangelogReader changelogReader;
         private final AtomicBoolean isRunning = new AtomicBoolean(true);
         private final Consumer<Set<TopicPartition>> offsetResetter;
-        private final Map<TaskId, Task> updatingTasks = new HashMap<>();
+        private final Map<TaskId, Task> updatingTasks = new 
ConcurrentHashMap<>();
         private final Logger log;
 
         public StateUpdaterThread(final String name,
@@ -72,7 +73,7 @@ public class DefaultStateUpdater implements StateUpdater {
             log = logContext.logger(DefaultStateUpdater.class);
         }
 
-        public Collection<Task> getAllUpdatingTasks() {
+        public Collection<Task> getUpdatingTasks() {
             return updatingTasks.values();
         }
 
@@ -117,11 +118,13 @@ public class DefaultStateUpdater implements StateUpdater {
             tasksAndActionsLock.lock();
             try {
                 for (final TaskAndAction taskAndAction : getTasksAndActions()) 
{
-                    final Task task = taskAndAction.task;
-                    final Action action = taskAndAction.action;
+                    final Action action = taskAndAction.getAction();
                     switch (action) {
                         case ADD:
-                            addTask(task);
+                            addTask(taskAndAction.getTask());
+                            break;
+                        case REMOVE:
+                            removeTask(taskAndAction.getTaskId());
                             break;
                     }
                 }
@@ -149,7 +152,7 @@ public class DefaultStateUpdater implements StateUpdater {
             log.error("An unexpected error occurred within the state updater 
thread: " + runtimeException);
             final ExceptionAndTasks exceptionAndTasks = new 
ExceptionAndTasks(new HashSet<>(updatingTasks.values()), runtimeException);
             updatingTasks.clear();
-            failedTasks.add(exceptionAndTasks);
+            exceptionsAndFailedTasks.add(exceptionAndTasks);
             isRunning.set(false);
         }
 
@@ -164,7 +167,7 @@ public class DefaultStateUpdater implements StateUpdater {
                 }
                 corruptedTasks.add(corruptedTask);
             }
-            failedTasks.add(new ExceptionAndTasks(corruptedTasks, 
taskCorruptedException));
+            exceptionsAndFailedTasks.add(new ExceptionAndTasks(corruptedTasks, 
taskCorruptedException));
         }
 
         private void handleStreamsException(final StreamsException 
streamsException) {
@@ -175,7 +178,7 @@ public class DefaultStateUpdater implements StateUpdater {
             } else {
                 exceptionAndTasks = 
handleStreamsExceptionWithoutTask(streamsException);
             }
-            failedTasks.add(exceptionAndTasks);
+            exceptionsAndFailedTasks.add(exceptionAndTasks);
         }
 
         private ExceptionAndTasks handleStreamsExceptionWithTask(final 
StreamsException streamsException) {
@@ -230,16 +233,15 @@ public class DefaultStateUpdater implements StateUpdater {
 
         private void addTask(final Task task) {
             if (isStateless(task)) {
-                log.debug("Stateless active task " + task.id() + " was added 
to the state updater");
                 addTaskToRestoredTasks((StreamTask) task);
+                log.debug("Stateless active task " + task.id() + " was added 
to the restored tasks of the state updater");
             } else {
+                updatingTasks.put(task.id(), task);
                 if (task.isActive()) {
-                    updatingTasks.put(task.id(), task);
-                    log.debug("Stateful active task " + task.id() + " was 
added to the state updater");
+                    log.debug("Stateful active task " + task.id() + " was 
added to the updating tasks of the state updater");
                     changelogReader.enforceRestoreActive();
                 } else {
-                    updatingTasks.put(task.id(), task);
-                    log.debug("Standby task " + task.id() + " was added to the 
state updater");
+                    log.debug("Standby task " + task.id() + " was added to the 
updating tasks of the state updater");
                     if (updatingTasks.size() == 1) {
                         changelogReader.transitToUpdateStandby();
                     }
@@ -247,6 +249,19 @@ public class DefaultStateUpdater implements StateUpdater {
             }
         }
 
+        private void removeTask(final TaskId taskId) {
+            final Task task = updatingTasks.remove(taskId);
+            if (task != null) {
+                final Collection<TopicPartition> changelogPartitions = 
task.changelogPartitions();
+                changelogReader.unregister(changelogPartitions);
+                removedTasks.add(task);
+                log.debug((task.isActive() ? "Active" : "Standby")
+                    + " task " + task.id() + " was removed from the updating 
tasks and added to the removed tasks.");
+            } else {
+                log.debug("Task " + taskId + " was not removed since it is not 
updating.");
+            }
+        }
+
         private boolean isStateless(final Task task) {
             return task.changelogPartitions().isEmpty() && task.isActive();
         }
@@ -277,20 +292,6 @@ public class DefaultStateUpdater implements StateUpdater {
         }
     }
 
-    enum Action {
-        ADD
-    }
-
-    private static class TaskAndAction {
-        public final Task task;
-        public final Action action;
-
-        public TaskAndAction(final Task task, final Action action) {
-            this.task = task;
-            this.action = action;
-        }
-    }
-
     private final Time time;
     private final ChangelogReader changelogReader;
     private final Consumer<Set<TopicPartition>> offsetResetter;
@@ -300,7 +301,8 @@ public class DefaultStateUpdater implements StateUpdater {
     private final Queue<StreamTask> restoredActiveTasks = new LinkedList<>();
     private final Lock restoredActiveTasksLock = new ReentrantLock();
     private final Condition restoredActiveTasksCondition = 
restoredActiveTasksLock.newCondition();
-    private final BlockingQueue<ExceptionAndTasks> failedTasks = new 
LinkedBlockingQueue<>();
+    private final BlockingQueue<ExceptionAndTasks> exceptionsAndFailedTasks = 
new LinkedBlockingQueue<>();
+    private final BlockingQueue<Task> removedTasks = new 
LinkedBlockingQueue<>();
     private CountDownLatch shutdownGate;
 
     private StateUpdaterThread stateUpdaterThread = null;
@@ -325,7 +327,7 @@ public class DefaultStateUpdater implements StateUpdater {
 
         tasksAndActionsLock.lock();
         try {
-            tasksAndActions.add(new TaskAndAction(task, Action.ADD));
+            tasksAndActions.add(TaskAndAction.createAddTask(task));
             tasksAndActionsCondition.signalAll();
         } finally {
             tasksAndActionsLock.unlock();
@@ -342,11 +344,18 @@ public class DefaultStateUpdater implements StateUpdater {
     }
 
     @Override
-    public void remove(final Task task) {
+    public void remove(final TaskId taskId) {
+        tasksAndActionsLock.lock();
+        try {
+            tasksAndActions.add(TaskAndAction.createRemoveTask(taskId));
+            tasksAndActionsCondition.signalAll();
+        } finally {
+            tasksAndActionsLock.unlock();
+        }
     }
 
     @Override
-    public Set<StreamTask> getRestoredActiveTasks(final Duration timeout) {
+    public Set<StreamTask> drainRestoredActiveTasks(final Duration timeout) {
         final long timeoutMs = timeout.toMillis();
         final long startTime = time.milliseconds();
         final long deadline = startTime + timeoutMs;
@@ -375,52 +384,42 @@ public class DefaultStateUpdater implements StateUpdater {
     }
 
     @Override
-    public List<ExceptionAndTasks> getFailedTasksAndExceptions() {
+    public Set<Task> drainRemovedTasks() {
+        final List<Task> result = new ArrayList<>();
+        removedTasks.drainTo(result);
+        return new HashSet<>(result);
+    }
+
+    @Override
+    public List<ExceptionAndTasks> drainExceptionsAndFailedTasks() {
         final List<ExceptionAndTasks> result = new ArrayList<>();
-        failedTasks.drainTo(result);
+        exceptionsAndFailedTasks.drainTo(result);
         return result;
     }
 
-    @Override
-    public Set<Task> getAllTasks() {
-        tasksAndActionsLock.lock();
+    public Set<StandbyTask> getUpdatingStandbyTasks() {
+        return Collections.unmodifiableSet(new 
HashSet<>(stateUpdaterThread.getUpdatingStandbyTasks()));
+    }
+
+    public Set<Task> getUpdatingTasks() {
+        return Collections.unmodifiableSet(new 
HashSet<>(stateUpdaterThread.getUpdatingTasks()));
+    }
+
+    public Set<StreamTask> getRestoredActiveTasks() {
         restoredActiveTasksLock.lock();
         try {
-            final Set<Task> allTasks = new HashSet<>();
-            allTasks.addAll(tasksAndActions.stream()
-                .filter(t -> t.action == Action.ADD)
-                .map(t -> t.task)
-                .collect(Collectors.toList())
-            );
-            allTasks.addAll(stateUpdaterThread.getAllUpdatingTasks());
-            allTasks.addAll(restoredActiveTasks);
-            return Collections.unmodifiableSet(allTasks);
+            return Collections.unmodifiableSet(new 
HashSet<>(restoredActiveTasks));
         } finally {
             restoredActiveTasksLock.unlock();
-            tasksAndActionsLock.unlock();
         }
     }
 
-    @Override
-    public Set<StandbyTask> getStandbyTasks() {
-        tasksAndActionsLock.lock();
-        try {
-            final Set<StandbyTask> standbyTasks = new HashSet<>();
-            standbyTasks.addAll(tasksAndActions.stream()
-                .filter(t -> t.action == Action.ADD)
-                .filter(t -> !t.task.isActive())
-                .map(t -> (StandbyTask) t.task)
-                .collect(Collectors.toList())
-            );
-            standbyTasks.addAll(getUpdatingStandbyTasks());
-            return Collections.unmodifiableSet(standbyTasks);
-        } finally {
-            tasksAndActionsLock.unlock();
-        }
+    public List<ExceptionAndTasks> getExceptionsAndFailedTasks() {
+        return Collections.unmodifiableList(new 
ArrayList<>(exceptionsAndFailedTasks));
     }
 
-    public Set<StandbyTask> getUpdatingStandbyTasks() {
-        return Collections.unmodifiableSet(new 
HashSet<>(stateUpdaterThread.getUpdatingStandbyTasks()));
+    public Set<Task> getRemovedTasks() {
+        return Collections.unmodifiableSet(new HashSet<>(removedTasks));
     }
 
     @Override
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
index d2d4ab71ad..42e65d4adb 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
@@ -16,65 +16,101 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import org.apache.kafka.streams.processor.TaskId;
+
 import java.time.Duration;
+import java.util.Collections;
 import java.util.List;
+import java.util.Objects;
 import java.util.Set;
 
 public interface StateUpdater {
 
     class ExceptionAndTasks {
-        public final Set<Task> tasks;
-        public final RuntimeException exception;
+        private final Set<Task> tasks;
+        private final RuntimeException exception;
 
         public ExceptionAndTasks(final Set<Task> tasks, final RuntimeException 
exception) {
-            this.tasks = tasks;
-            this.exception = exception;
+            this.tasks = Objects.requireNonNull(tasks);
+            this.exception = Objects.requireNonNull(exception);
+        }
+
+        public Set<Task> tasks() {
+            return Collections.unmodifiableSet(tasks);
+        }
+
+        public RuntimeException exception() {
+            return exception;
+        }
+
+        @Override
+        public boolean equals(final Object o) {
+            if (this == o) return true;
+            if (!(o instanceof ExceptionAndTasks)) return false;
+            final ExceptionAndTasks that = (ExceptionAndTasks) o;
+            return tasks.equals(that.tasks) && 
exception.equals(that.exception);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(tasks, exception);
         }
     }
 
     /**
      * Adds a task (active or standby) to the state updater.
      *
+     * This method does not block until the task is added to the state updater.
+     *
      * @param task task to add
      */
     void add(final Task task);
 
     /**
-     * Removes a task (active or standby) from the state updater.
+     * Removes a task (active or standby) from the state updater and adds the 
removed task to the removed tasks.
+     *
+     * This method does not block until the removed task is removed from the 
state updater.
      *
-     * @param task task ro remove
+     * A task that has already been moved to the restored active tasks or to the failed tasks is not removed by this method.
+     * Stateless tasks will never be added to the removed tasks since they are 
immediately added to the
+     * restored active tasks.
+     *
+     * @param taskId ID of the task to remove
      */
-    void remove(final Task task);
+    void remove(final TaskId taskId);
 
     /**
-     * Gets restored active tasks from state restoration/update
+     * Drains the restored active tasks from the state updater.
+     *
+     * The returned active tasks are removed from the state updater.
      *
      * @param timeout duration how long the calling thread should wait for 
restored active tasks
      *
      * @return set of active tasks with up-to-date states
      */
-    Set<StreamTask> getRestoredActiveTasks(final Duration timeout);
+    Set<StreamTask> drainRestoredActiveTasks(final Duration timeout);
 
-    /**
-     * Gets failed tasks and the corresponding exceptions
-     *
-     * @return list of failed tasks and the corresponding exceptions
-     */
-    List<ExceptionAndTasks> getFailedTasksAndExceptions();
 
     /**
-     * Get all tasks (active and standby) that are managed by the state 
updater.
+     * Drains the removed tasks (active and standby) from the state updater.
+     *
+     * Removed tasks returned by this method are tasks explicitly removed from the state updater via {@link #remove(TaskId)}. These do not
+     * include restored or failed tasks.
+     *
+     * The returned removed tasks are removed from the state updater.
      *
-     * @return set of tasks managed by the state updater
+     * @return set of tasks removed from the state updater
      */
-    Set<Task> getAllTasks();
+    Set<Task> drainRemovedTasks();
 
     /**
-     * Get standby tasks that are managed by the state updater.
+     * Drains the failed tasks and the corresponding exceptions.
      *
-     * @return set of standby tasks managed by the state updater
+     * The returned failed tasks are removed from the state updater.
+     *
+     * @return list of failed tasks and the corresponding exceptions
      */
-    Set<StandbyTask> getStandbyTasks();
+    List<ExceptionAndTasks> drainExceptionsAndFailedTasks();
 
     /**
      * Shuts down the state updater.
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java
new file mode 100644
index 0000000000..4c4316a864
--- /dev/null
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.streams.processor.TaskId;
+
+import java.util.Objects;
+
+public class TaskAndAction {
+
+    enum Action {
+        ADD,
+        REMOVE
+    }
+
+    private final Task task;
+    private final TaskId taskId;
+    private final Action action;
+
+    private TaskAndAction(final Task task, final TaskId taskId, final Action 
action) {
+        this.task = task;
+        this.taskId = taskId;
+        this.action = action;
+    }
+
+    public static TaskAndAction createAddTask(final Task task) {
+        Objects.requireNonNull(task, "Task to add is null!");
+        return new TaskAndAction(task, null, Action.ADD);
+    }
+
+    public static TaskAndAction createRemoveTask(final TaskId taskId) {
+        Objects.requireNonNull(taskId, "Task ID of task to remove is null!");
+        return new TaskAndAction(null, taskId, Action.REMOVE);
+    }
+
+    public Task getTask() {
+        if (action != Action.ADD) {
+            throw new IllegalStateException("Action type " + action + " cannot 
have a task!");
+        }
+        return task;
+    }
+
+    public TaskId getTaskId() {
+        if (action != Action.REMOVE) {
+            throw new IllegalStateException("Action type " + action + " cannot 
have a task ID!");
+        }
+        return taskId;
+    }
+
+    public Action getAction() {
+        return action;
+    }
+}
\ No newline at end of file
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
index c9fa1abede..fa50380f7f 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
@@ -36,14 +36,12 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.stream.Collectors;
 
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.apache.kafka.test.TestUtils.waitForCondition;
 import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.ArgumentMatchers.anyMap;
@@ -154,7 +152,9 @@ class DefaultStateUpdaterTest {
         }
 
         verifyRestoredActiveTasks(tasks);
-        assertTrue(stateUpdater.getAllTasks().isEmpty());
+        verifyUpdatingTasks();
+        verifyExceptionsAndFailedTasks();
+        verifyRemovedTasks();
     }
 
     @Test
@@ -173,9 +173,11 @@ class DefaultStateUpdaterTest {
         stateUpdater.add(task);
 
         verifyRestoredActiveTasks(task);
-        assertTrue(stateUpdater.getAllTasks().isEmpty());
+        verifyUpdatingTasks();
+        verifyExceptionsAndFailedTasks();
+        verifyRemovedTasks();
         verify(changelogReader, times(1)).enforceRestoreActive();
-        verify(changelogReader, atLeast(1)).restore(anyMap());
+        verify(changelogReader, atLeast(3)).restore(anyMap());
         verify(task).completeRestoration(offsetResetter);
         verify(changelogReader, never()).transitToUpdateStandby();
     }
@@ -201,7 +203,9 @@ class DefaultStateUpdaterTest {
         stateUpdater.add(task3);
 
         verifyRestoredActiveTasks(task3, task1, task2);
-        assertTrue(stateUpdater.getAllTasks().isEmpty());
+        verifyUpdatingTasks();
+        verifyExceptionsAndFailedTasks();
+        verifyRemovedTasks();
         verify(changelogReader, times(3)).enforceRestoreActive();
         verify(changelogReader, atLeast(4)).restore(anyMap());
         verify(task3).completeRestoration(offsetResetter);
@@ -210,6 +214,25 @@ class DefaultStateUpdaterTest {
         verify(changelogReader, never()).transitToUpdateStandby();
     }
 
+    @Test
+    public void shouldDrainRestoredActiveTasks() throws Exception {
+        
assertTrue(stateUpdater.drainRestoredActiveTasks(Duration.ZERO).isEmpty());
+
+        final StreamTask task1 = createStatelessTaskInStateRestoring(TASK_0_0);
+        stateUpdater.add(task1);
+
+        verifyDrainingRestoredActiveTasks(task1);
+
+        final StreamTask task2 = createStatelessTaskInStateRestoring(TASK_1_1);
+        final StreamTask task3 = createStatelessTaskInStateRestoring(TASK_1_0);
+        final StreamTask task4 = createStatelessTaskInStateRestoring(TASK_0_2);
+        stateUpdater.add(task2);
+        stateUpdater.add(task3);
+        stateUpdater.add(task4);
+
+        verifyDrainingRestoredActiveTasks(task2, task3, task4);
+    }
+
     @Test
     public void shouldUpdateSingleStandbyTask() throws Exception {
         final StandbyTask task = createStandbyTaskInStateRunning(
@@ -236,6 +259,9 @@ class DefaultStateUpdaterTest {
         }
 
         verifyUpdatingStandbyTasks(tasks);
+        verifyRestoredActiveTasks();
+        verifyExceptionsAndFailedTasks();
+        verifyRemovedTasks();
         verify(changelogReader, times(1)).transitToUpdateStandby();
         verify(changelogReader, 
timeout(VERIFICATION_TIMEOUT).atLeast(1)).restore(anyMap());
         verify(changelogReader, never()).enforceRestoreActive();
@@ -260,10 +286,12 @@ class DefaultStateUpdaterTest {
         stateUpdater.add(task4);
 
         verifyRestoredActiveTasks(task2, task1);
+        verifyUpdatingStandbyTasks(task4, task3);
+        verifyExceptionsAndFailedTasks();
+        verifyRemovedTasks();
         verify(task1).completeRestoration(offsetResetter);
         verify(task2).completeRestoration(offsetResetter);
         verify(changelogReader, atLeast(3)).restore(anyMap());
-        verifyUpdatingStandbyTasks(task4, task3);
         final InOrder orderVerifier = inOrder(changelogReader, task1, task2);
         orderVerifier.verify(changelogReader, times(2)).enforceRestoreActive();
         orderVerifier.verify(changelogReader, 
times(1)).transitToUpdateStandby();
@@ -293,37 +321,155 @@ class DefaultStateUpdaterTest {
 
         stateUpdater.add(task3);
 
-        verifyRestoredActiveTasks(task3);
+        verifyRestoredActiveTasks(task1, task3);
         verify(task3).completeRestoration(offsetResetter);
         orderVerifier.verify(changelogReader, times(1)).enforceRestoreActive();
         orderVerifier.verify(changelogReader, 
times(1)).transitToUpdateStandby();
     }
 
+    @Test
+    public void shouldRemoveActiveStatefulTask() throws Exception {
+        final StreamTask task = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
+        shouldRemoveStatefulTask(task);
+    }
+
+    @Test
+    public void shouldRemoveStandbyTask() throws Exception {
+        final StandbyTask task = createStandbyTaskInStateRunning(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
+        shouldRemoveStatefulTask(task);
+    }
+
+    private void shouldRemoveStatefulTask(final Task task) throws Exception {
+        when(changelogReader.completedChangelogs())
+            .thenReturn(Collections.emptySet());
+        when(changelogReader.allChangelogsCompleted())
+            .thenReturn(false);
+        stateUpdater.add(task);
+
+        stateUpdater.remove(TASK_0_0);
+
+        verifyRemovedTasks(task);
+        verifyRestoredActiveTasks();
+        verifyUpdatingTasks();
+        verifyExceptionsAndFailedTasks();
+        
verify(changelogReader).unregister(Collections.singletonList(TOPIC_PARTITION_A_0));
+    }
+
+    @Test
+    public void shouldNotRemoveActiveStatefulTaskFromRestoredActiveTasks() 
throws Exception {
+        final StreamTask task = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
+        shouldNotRemoveTaskFromRestoredActiveTasks(task);
+    }
+
+    @Test
+    public void shouldNotRemoveStatelessTaskFromRestoredActiveTasks() throws 
Exception {
+        final StreamTask task = createStatelessTaskInStateRestoring(TASK_0_0);
+        shouldNotRemoveTaskFromRestoredActiveTasks(task);
+    }
+
+    private void shouldNotRemoveTaskFromRestoredActiveTasks(final StreamTask 
task) throws Exception {
+        final StreamTask controlTask = 
createActiveStatefulTaskInStateRestoring(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_B_0));
+        when(changelogReader.completedChangelogs())
+            .thenReturn(Collections.singleton(TOPIC_PARTITION_A_0));
+        when(changelogReader.allChangelogsCompleted())
+            .thenReturn(false);
+        stateUpdater.add(task);
+        stateUpdater.add(controlTask);
+        verifyRestoredActiveTasks(task);
+
+        stateUpdater.remove(task.id());
+        stateUpdater.remove(controlTask.id());
+
+        verifyRemovedTasks(controlTask);
+        verifyRestoredActiveTasks(task);
+        verifyUpdatingTasks();
+        verifyExceptionsAndFailedTasks();
+    }
+
+    @Test
+    public void shouldNotRemoveActiveStatefulTaskFromFailedTasks() throws 
Exception {
+        final StreamTask task = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
+        shouldNotRemoveTaskFromFailedTasks(task);
+    }
+
+    @Test
+    public void shouldNotRemoveStandbyTaskFromFailedTasks() throws Exception {
+        final StandbyTask task = createStandbyTaskInStateRunning(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
+        shouldNotRemoveTaskFromFailedTasks(task);
+    }
+
+    private void shouldNotRemoveTaskFromFailedTasks(final Task task) throws 
Exception {
+        final StreamTask controlTask = 
createActiveStatefulTaskInStateRestoring(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_B_0));
+        final StreamsException streamsException = new 
StreamsException("Something happened", task.id());
+        when(changelogReader.completedChangelogs())
+            .thenReturn(Collections.emptySet());
+        when(changelogReader.allChangelogsCompleted())
+            .thenReturn(false);
+        doNothing()
+            .doThrow(streamsException)
+            .doNothing()
+            .when(changelogReader).restore(anyMap());
+        stateUpdater.add(task);
+        stateUpdater.add(controlTask);
+        final ExceptionAndTasks expectedExceptionAndTasks = new 
ExceptionAndTasks(mkSet(task), streamsException);
+        verifyExceptionsAndFailedTasks(expectedExceptionAndTasks);
+
+        stateUpdater.remove(task.id());
+        stateUpdater.remove(controlTask.id());
+
+        verifyRemovedTasks(controlTask);
+        verifyExceptionsAndFailedTasks(expectedExceptionAndTasks);
+        verifyUpdatingTasks();
+        verifyRestoredActiveTasks();
+    }
+
+    @Test
+    public void shouldDrainRemovedTasks() throws Exception {
+        assertTrue(stateUpdater.drainRemovedTasks().isEmpty());
+        when(changelogReader.completedChangelogs())
+            .thenReturn(Collections.emptySet());
+        when(changelogReader.allChangelogsCompleted())
+            .thenReturn(false);
+
+        final StreamTask task1 = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_B_0));
+        stateUpdater.add(task1);
+        stateUpdater.remove(task1.id());
+
+        verifyDrainingRemovedTasks(task1);
+
+        final StreamTask task2 = 
createActiveStatefulTaskInStateRestoring(TASK_1_1, 
Collections.singletonList(TOPIC_PARTITION_C_0));
+        final StreamTask task3 = 
createActiveStatefulTaskInStateRestoring(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
+        final StreamTask task4 = 
createActiveStatefulTaskInStateRestoring(TASK_0_2, 
Collections.singletonList(TOPIC_PARTITION_D_0));
+        stateUpdater.add(task2);
+        stateUpdater.remove(task2.id());
+        stateUpdater.add(task3);
+        stateUpdater.remove(task3.id());
+        stateUpdater.add(task4);
+        stateUpdater.remove(task4.id());
+
+        verifyDrainingRemovedTasks(task2, task3, task4);
+    }
+
     @Test
     public void 
shouldAddFailedTasksToQueueWhenRestoreThrowsStreamsExceptionWithoutTask() 
throws Exception {
         final StreamTask task1 = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
         final StandbyTask task2 = createStandbyTaskInStateRunning(TASK_0_2, 
Collections.singletonList(TOPIC_PARTITION_B_0));
-        final String expectedMessage = "The Streams were crossed!";
-        final StreamsException expectedStreamsException = new 
StreamsException(expectedMessage);
+        final String exceptionMessage = "The Streams were crossed!";
+        final StreamsException streamsException = new 
StreamsException(exceptionMessage);
         final Map<TaskId, Task> updatingTasks = mkMap(
             mkEntry(task1.id(), task1),
             mkEntry(task2.id(), task2)
         );
-        
doNothing().doThrow(expectedStreamsException).doNothing().when(changelogReader).restore(updatingTasks);
+        
doNothing().doThrow(streamsException).when(changelogReader).restore(updatingTasks);
 
         stateUpdater.add(task1);
         stateUpdater.add(task2);
 
-        final List<ExceptionAndTasks> failedTasks = getFailedTasks(1);
-        assertEquals(1, failedTasks.size());
-        final ExceptionAndTasks actualFailedTasks = failedTasks.get(0);
-        assertEquals(2, actualFailedTasks.tasks.size());
-        assertTrue(actualFailedTasks.tasks.containsAll(Arrays.asList(task1, 
task2)));
-        assertTrue(actualFailedTasks.exception instanceof StreamsException);
-        final StreamsException actualException = (StreamsException) 
actualFailedTasks.exception;
-        assertFalse(actualException.taskId().isPresent());
-        assertEquals(expectedMessage, actualException.getMessage());
-        assertTrue(stateUpdater.getAllTasks().isEmpty());
+        final ExceptionAndTasks expectedExceptionAndTasks = new 
ExceptionAndTasks(mkSet(task1, task2), streamsException);
+        verifyExceptionsAndFailedTasks(expectedExceptionAndTasks);
+        verifyRemovedTasks();
+        verifyUpdatingTasks();
+        verifyRestoredActiveTasks();
     }
 
     @Test
@@ -331,9 +477,9 @@ class DefaultStateUpdaterTest {
         final StreamTask task1 = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
         final StreamTask task2 = 
createActiveStatefulTaskInStateRestoring(TASK_0_2, 
Collections.singletonList(TOPIC_PARTITION_B_0));
         final StandbyTask task3 = createStandbyTaskInStateRunning(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_C_0));
-        final String expectedMessage = "The Streams were crossed!";
-        final StreamsException expectedStreamsException1 = new 
StreamsException(expectedMessage, task1.id());
-        final StreamsException expectedStreamsException2 = new 
StreamsException(expectedMessage, task3.id());
+        final String exceptionMessage = "The Streams were crossed!";
+        final StreamsException streamsException1 = new 
StreamsException(exceptionMessage, task1.id());
+        final StreamsException streamsException2 = new 
StreamsException(exceptionMessage, task3.id());
         final Map<TaskId, Task> updatingTasksBeforeFirstThrow = mkMap(
             mkEntry(task1.id(), task1),
             mkEntry(task2.id(), task2),
@@ -344,36 +490,22 @@ class DefaultStateUpdaterTest {
             mkEntry(task3.id(), task3)
         );
         doNothing()
-            .doThrow(expectedStreamsException1)
+            .doThrow(streamsException1)
             .when(changelogReader).restore(updatingTasksBeforeFirstThrow);
         doNothing()
-            .doThrow(expectedStreamsException2)
+            .doThrow(streamsException2)
             .when(changelogReader).restore(updatingTasksBeforeSecondThrow);
 
         stateUpdater.add(task1);
         stateUpdater.add(task2);
         stateUpdater.add(task3);
 
-        final List<ExceptionAndTasks> failedTasks = getFailedTasks(2);
-        assertEquals(2, failedTasks.size());
-        final ExceptionAndTasks actualFailedTasks1 = failedTasks.get(0);
-        assertEquals(1, actualFailedTasks1.tasks.size());
-        assertTrue(actualFailedTasks1.tasks.contains(task1));
-        assertTrue(actualFailedTasks1.exception instanceof StreamsException);
-        final StreamsException actualException1 = (StreamsException) 
actualFailedTasks1.exception;
-        assertTrue(actualException1.taskId().isPresent());
-        assertEquals(task1.id(), actualException1.taskId().get());
-        assertEquals(expectedMessage, actualException1.getMessage());
-        final ExceptionAndTasks actualFailedTasks2 = failedTasks.get(1);
-        assertEquals(1, actualFailedTasks2.tasks.size());
-        assertTrue(actualFailedTasks2.tasks.contains(task3));
-        assertTrue(actualFailedTasks2.exception instanceof StreamsException);
-        final StreamsException actualException2 = (StreamsException) 
actualFailedTasks2.exception;
-        assertTrue(actualException2.taskId().isPresent());
-        assertEquals(task3.id(), actualException2.taskId().get());
-        assertEquals(expectedMessage, actualException2.getMessage());
-        assertEquals(1, stateUpdater.getAllTasks().size());
-        assertTrue(stateUpdater.getAllTasks().contains(task2));
+        final ExceptionAndTasks expectedExceptionAndTasks1 = new 
ExceptionAndTasks(mkSet(task1), streamsException1);
+        final ExceptionAndTasks expectedExceptionAndTasks2 = new 
ExceptionAndTasks(mkSet(task3), streamsException2);
+        verifyExceptionsAndFailedTasks(expectedExceptionAndTasks1, 
expectedExceptionAndTasks2);
+        verifyUpdatingTasks(task2);
+        verifyRestoredActiveTasks();
+        verifyRemovedTasks();
     }
 
     @Test
@@ -394,18 +526,11 @@ class DefaultStateUpdaterTest {
         stateUpdater.add(task2);
         stateUpdater.add(task3);
 
-        final List<ExceptionAndTasks> failedTasks = getFailedTasks(1);
-        assertEquals(1, failedTasks.size());
-        final List<Task> expectedFailedTasks = Arrays.asList(task1, task2);
-        final ExceptionAndTasks actualFailedTasks = failedTasks.get(0);
-        assertEquals(2, actualFailedTasks.tasks.size());
-        assertTrue(actualFailedTasks.tasks.containsAll(expectedFailedTasks));
-        assertTrue(actualFailedTasks.exception instanceof 
TaskCorruptedException);
-        final TaskCorruptedException actualException = 
(TaskCorruptedException) actualFailedTasks.exception;
-        final Set<TaskId> corruptedTasks = actualException.corruptedTasks();
-        
assertTrue(corruptedTasks.containsAll(expectedFailedTasks.stream().map(Task::id).collect(Collectors.toList())));
-        assertEquals(1, stateUpdater.getAllTasks().size());
-        assertTrue(stateUpdater.getAllTasks().contains(task3));
+        final ExceptionAndTasks expectedExceptionAndTasks = new 
ExceptionAndTasks(mkSet(task1, task2), taskCorruptedException);
+        verifyExceptionsAndFailedTasks(expectedExceptionAndTasks);
+        verifyUpdatingTasks(task3);
+        verifyRestoredActiveTasks();
+        verifyRemovedTasks();
     }
 
     @Test
@@ -422,30 +547,113 @@ class DefaultStateUpdaterTest {
         stateUpdater.add(task1);
         stateUpdater.add(task2);
 
-        final List<ExceptionAndTasks> failedTasks = getFailedTasks(1);
-        final List<Task> expectedFailedTasks = Arrays.asList(task1, task2);
-        final ExceptionAndTasks actualFailedTasks = failedTasks.get(0);
-        assertEquals(2, actualFailedTasks.tasks.size());
-        assertTrue(actualFailedTasks.tasks.containsAll(expectedFailedTasks));
-        assertTrue(actualFailedTasks.exception instanceof 
IllegalStateException);
-        final IllegalStateException actualException = (IllegalStateException) 
actualFailedTasks.exception;
-        assertEquals(actualException.getMessage(), 
illegalStateException.getMessage());
-        assertTrue(stateUpdater.getAllTasks().isEmpty());
+        final ExceptionAndTasks expectedExceptionAndTasks = new 
ExceptionAndTasks(mkSet(task1, task2), illegalStateException);
+        verifyExceptionsAndFailedTasks(expectedExceptionAndTasks);
+        verifyUpdatingTasks();
+        verifyRestoredActiveTasks();
+        verifyRemovedTasks();
+    }
+
+    @Test
+    public void shouldDrainFailedTasksAndExceptions() throws Exception {
+        assertTrue(stateUpdater.drainExceptionsAndFailedTasks().isEmpty());
+
+        final StreamTask task1 = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_B_0));
+        final StreamTask task2 = 
createActiveStatefulTaskInStateRestoring(TASK_1_1, 
Collections.singletonList(TOPIC_PARTITION_C_0));
+        final StreamTask task3 = 
createActiveStatefulTaskInStateRestoring(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
+        final StreamTask task4 = 
createActiveStatefulTaskInStateRestoring(TASK_0_2, 
Collections.singletonList(TOPIC_PARTITION_D_0));
+        final String exceptionMessage = "The Streams were crossed!";
+        final StreamsException streamsException1 = new 
StreamsException(exceptionMessage, task1.id());
+        final Map<TaskId, Task> updatingTasks1 = mkMap(
+            mkEntry(task1.id(), task1)
+        );
+        doThrow(streamsException1)
+            .when(changelogReader).restore(updatingTasks1);
+        final StreamsException streamsException2 = new 
StreamsException(exceptionMessage, task2.id());
+        final StreamsException streamsException3 = new 
StreamsException(exceptionMessage, task3.id());
+        final StreamsException streamsException4 = new 
StreamsException(exceptionMessage, task4.id());
+        final Map<TaskId, Task> updatingTasks2 = mkMap(
+            mkEntry(task2.id(), task2),
+            mkEntry(task3.id(), task3),
+            mkEntry(task4.id(), task4)
+        );
+        
doThrow(streamsException2).when(changelogReader).restore(updatingTasks2);
+        final Map<TaskId, Task> updatingTasks3 = mkMap(
+            mkEntry(task3.id(), task3),
+            mkEntry(task4.id(), task4)
+        );
+        
doThrow(streamsException3).when(changelogReader).restore(updatingTasks3);
+        final Map<TaskId, Task> updatingTasks4 = mkMap(
+            mkEntry(task4.id(), task4)
+        );
+        
doThrow(streamsException4).when(changelogReader).restore(updatingTasks4);
+
+        stateUpdater.add(task1);
+
+        final ExceptionAndTasks expectedExceptionAndTasks1 = new 
ExceptionAndTasks(mkSet(task1), streamsException1);
+        verifyDrainingExceptionsAndFailedTasks(expectedExceptionAndTasks1);
+
+        stateUpdater.add(task2);
+        stateUpdater.add(task3);
+        stateUpdater.add(task4);
+
+        final ExceptionAndTasks expectedExceptionAndTasks2 = new 
ExceptionAndTasks(mkSet(task2), streamsException2);
+        final ExceptionAndTasks expectedExceptionAndTasks3 = new 
ExceptionAndTasks(mkSet(task3), streamsException3);
+        final ExceptionAndTasks expectedExceptionAndTasks4 = new 
ExceptionAndTasks(mkSet(task4), streamsException4);
+        verifyDrainingExceptionsAndFailedTasks(expectedExceptionAndTasks2, 
expectedExceptionAndTasks3, expectedExceptionAndTasks4);
     }
 
     private void verifyRestoredActiveTasks(final StreamTask... tasks) throws 
Exception {
+        if (tasks.length == 0) {
+            assertTrue(stateUpdater.getRestoredActiveTasks().isEmpty());
+        } else {
+            final Set<StreamTask> expectedRestoredTasks = mkSet(tasks);
+            final Set<StreamTask> restoredTasks = new HashSet<>();
+            waitForCondition(
+                () -> {
+                    
restoredTasks.addAll(stateUpdater.getRestoredActiveTasks());
+                    return restoredTasks.containsAll(expectedRestoredTasks);
+                },
+                VERIFICATION_TIMEOUT,
+                "Did not get all restored active task within the given 
timeout!"
+            );
+            assertEquals(expectedRestoredTasks.size(), restoredTasks.size());
+            assertTrue(restoredTasks.stream().allMatch(task -> task.state() == 
State.RESTORING));
+        }
+    }
+
+    private void verifyDrainingRestoredActiveTasks(final StreamTask... tasks) 
throws Exception {
         final Set<StreamTask> expectedRestoredTasks = mkSet(tasks);
         final Set<StreamTask> restoredTasks = new HashSet<>();
         waitForCondition(
             () -> {
-                
restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT)));
-                return restoredTasks.size() == expectedRestoredTasks.size();
+                
restoredTasks.addAll(stateUpdater.drainRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT)));
+                return restoredTasks.containsAll(expectedRestoredTasks);
             },
             VERIFICATION_TIMEOUT,
-            "Did not get any restored active task within the given timeout!"
+            "Did not get all restored active task within the given timeout!"
         );
-        assertTrue(restoredTasks.containsAll(expectedRestoredTasks));
-        assertEquals(expectedRestoredTasks.size(), 
restoredTasks.stream().filter(task -> task.state() == State.RESTORING).count());
+        assertEquals(expectedRestoredTasks.size(), restoredTasks.size());
+        
assertTrue(stateUpdater.drainRestoredActiveTasks(Duration.ZERO).isEmpty());
+    }
+
+    private void verifyUpdatingTasks(final Task... tasks) throws Exception {
+        if (tasks.length == 0) {
+            assertTrue(stateUpdater.getUpdatingTasks().isEmpty());
+        } else {
+            final Set<Task> expectedUpdatingTasks = mkSet(tasks);
+            final Set<Task> updatingTasks = new HashSet<>();
+            waitForCondition(
+                () -> {
+                    updatingTasks.addAll(stateUpdater.getUpdatingTasks());
+                    return updatingTasks.containsAll(expectedUpdatingTasks);
+                },
+                VERIFICATION_TIMEOUT,
+                "Did not get all updating task within the given timeout!"
+            );
+            assertEquals(expectedUpdatingTasks.size(), updatingTasks.size());
+            assertTrue(updatingTasks.stream().allMatch(task -> task.state() == 
State.RESTORING));
+        }
     }
 
     private void verifyUpdatingStandbyTasks(final StandbyTask... tasks) throws 
Exception {
@@ -454,27 +662,78 @@ class DefaultStateUpdaterTest {
         waitForCondition(
             () -> {
                 standbyTasks.addAll(stateUpdater.getUpdatingStandbyTasks());
-                return standbyTasks.size() == expectedStandbyTasks.size();
+                return standbyTasks.containsAll(expectedStandbyTasks);
             },
             VERIFICATION_TIMEOUT,
             "Did not see all standby task within the given timeout!"
         );
-        assertTrue(standbyTasks.containsAll(expectedStandbyTasks));
-        assertEquals(expectedStandbyTasks.size(), 
standbyTasks.stream().filter(t -> t.state() == State.RUNNING).count());
+        assertEquals(expectedStandbyTasks.size(), standbyTasks.size());
+        assertTrue(standbyTasks.stream().allMatch(task -> task.state() == 
State.RUNNING));
+    }
+
+    private void verifyRemovedTasks(final Task... tasks) throws Exception {
+        if (tasks.length == 0) {
+            assertTrue(stateUpdater.getRemovedTasks().isEmpty());
+        } else {
+            final Set<Task> expectedRemovedTasks = mkSet(tasks);
+            final Set<Task> removedTasks = new HashSet<>();
+            waitForCondition(
+                () -> {
+                    removedTasks.addAll(stateUpdater.getRemovedTasks());
+                    return removedTasks.containsAll(expectedRemovedTasks);
+                },
+                VERIFICATION_TIMEOUT,
+                "Did not get all removed task within the given timeout!"
+            );
+            assertEquals(expectedRemovedTasks.size(), removedTasks.size());
+            assertTrue(removedTasks.stream()
+                .allMatch(task -> task.isActive() && task.state() == 
State.RESTORING
+                        || !task.isActive() && task.state() == State.RUNNING));
+        }
     }
 
-    private List<ExceptionAndTasks> getFailedTasks(final int expectedCount) 
throws Exception {
+    private void verifyDrainingRemovedTasks(final Task... tasks) throws 
Exception {
+        final Set<Task> expectedRemovedTasks = mkSet(tasks);
+        final Set<Task> removedTasks = new HashSet<>();
+        waitForCondition(
+            () -> {
+                removedTasks.addAll(stateUpdater.drainRemovedTasks());
+                return removedTasks.containsAll(expectedRemovedTasks);
+            },
+            VERIFICATION_TIMEOUT,
+            "Did not get all removed tasks within the given timeout!"
+        );
+        assertEquals(expectedRemovedTasks.size(), removedTasks.size());
+        assertTrue(stateUpdater.drainRemovedTasks().isEmpty());
+    }
+
+    private void verifyExceptionsAndFailedTasks(final ExceptionAndTasks... 
exceptionsAndTasks) throws Exception {
+        final List<ExceptionAndTasks> expectedExceptionAndTasks = 
Arrays.asList(exceptionsAndTasks);
         final List<ExceptionAndTasks> failedTasks = new ArrayList<>();
         waitForCondition(
             () -> {
-                failedTasks.addAll(stateUpdater.getFailedTasksAndExceptions());
-                return failedTasks.size() >= expectedCount;
+                failedTasks.addAll(stateUpdater.getExceptionsAndFailedTasks());
+                return failedTasks.containsAll(expectedExceptionAndTasks);
             },
             VERIFICATION_TIMEOUT,
-            "Did not get enough failed tasks within the given timeout!"
+            "Did not get all exceptions and failed tasks within the given 
timeout!"
         );
+        assertEquals(expectedExceptionAndTasks.size(), new HashSet<>(failedTasks).size());
+    }
 
-        return failedTasks;
+    private void verifyDrainingExceptionsAndFailedTasks(final 
ExceptionAndTasks... exceptionsAndTasks) throws Exception {
+        final List<ExceptionAndTasks> expectedExceptionAndTasks = 
Arrays.asList(exceptionsAndTasks);
+        final List<ExceptionAndTasks> failedTasks = new ArrayList<>();
+        waitForCondition(
+            () -> {
+                
failedTasks.addAll(stateUpdater.drainExceptionsAndFailedTasks());
+                return failedTasks.containsAll(expectedExceptionAndTasks);
+            },
+            VERIFICATION_TIMEOUT,
+            "Did not get all exceptions and failed tasks within the given 
timeout!"
+        );
+        assertEquals(expectedExceptionAndTasks.size(), failedTasks.size());
+        assertTrue(stateUpdater.drainExceptionsAndFailedTasks().isEmpty());
     }
 
     private StreamTask createActiveStatefulTaskInStateRestoring(final TaskId 
taskId,
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskAndActionTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskAndActionTest.java
new file mode 100644
index 0000000000..39b927ee09
--- /dev/null
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskAndActionTest.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.streams.processor.TaskId;
+import org.junit.jupiter.api.Test;
+
+import static 
org.apache.kafka.streams.processor.internals.TaskAndAction.Action.ADD;
+import static 
org.apache.kafka.streams.processor.internals.TaskAndAction.Action.REMOVE;
+import static 
org.apache.kafka.streams.processor.internals.TaskAndAction.createAddTask;
+import static 
org.apache.kafka.streams.processor.internals.TaskAndAction.createRemoveTask;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
+
+class TaskAndActionTest {
+
+    @Test
+    public void shouldCreateAddTaskAction() {
+        final StreamTask task = mock(StreamTask.class);
+
+        final TaskAndAction addTask = createAddTask(task);
+
+        assertEquals(ADD, addTask.getAction());
+        assertEquals(task, addTask.getTask());
+        final Exception exception = assertThrows(IllegalStateException.class, 
addTask::getTaskId);
+        assertEquals("Action type ADD cannot have a task ID!", 
exception.getMessage());
+    }
+
+    @Test
+    public void shouldCreateRemoveTaskAction() {
+        final TaskId taskId = new TaskId(0, 0);
+
+        final TaskAndAction removeTask = createRemoveTask(taskId);
+
+        assertEquals(REMOVE, removeTask.getAction());
+        assertEquals(taskId, removeTask.getTaskId());
+        final Exception exception = assertThrows(IllegalStateException.class, 
removeTask::getTask);
+        assertEquals("Action type REMOVE cannot have a task!", 
exception.getMessage());
+    }
+
+    @Test
+    public void shouldThrowIfAddTaskActionIsCreatedWithNullTask() {
+        final Exception exception = assertThrows(NullPointerException.class, 
() -> createAddTask(null));
+        assertTrue(exception.getMessage().contains("Task to add is null!"));
+    }
+
+    @Test
+    public void shouldThrowIfRemoveTaskActionIsCreatedWithNullTaskId() {
+        final Exception exception = assertThrows(NullPointerException.class, 
() -> createRemoveTask(null));
+        assertTrue(exception.getMessage().contains("Task ID of task to remove 
is null!"));
+    }
+}

Reply via email to