This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 1ceaf30039 KAFKA-10199: Expose tasks in state updater (#12312)
1ceaf30039 is described below

commit 1ceaf30039e48199e40951c5a8d52894bb45e4d3
Author: Bruno Cadonna <cado...@apache.org>
AuthorDate: Fri Jun 24 18:33:24 2022 +0200

    KAFKA-10199: Expose tasks in state updater (#12312)
    
    This PR exposes the tasks managed by the state updater. The state updater 
manages all tasks that have been added to it and that have not yet been removed 
from it by draining one of its output queues (restored active tasks, removed 
tasks, or exceptions and failed tasks).
    
    Reviewers: Guozhang Wang <wangg...@gmail.com>
---
 .../processor/internals/DefaultStateUpdater.java   | 149 ++++++---
 .../streams/processor/internals/StateUpdater.java  |  61 +++-
 .../internals/DefaultStateUpdaterTest.java         | 332 +++++++++++++++++----
 3 files changed, 436 insertions(+), 106 deletions(-)

diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
index cc580a3b38..0e84574c5c 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
@@ -47,7 +47,9 @@ import java.util.concurrent.locks.Condition;
 import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.ReentrantLock;
 import java.util.function.Consumer;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 public class DefaultStateUpdater implements StateUpdater {
 
@@ -86,7 +88,7 @@ public class DefaultStateUpdater implements StateUpdater {
         }
 
         public boolean onlyStandbyTasksLeft() {
-            return !updatingTasks.isEmpty() && 
updatingTasks.values().stream().noneMatch(Task::isActive);
+            return !updatingTasks.isEmpty() && 
updatingTasks.values().stream().allMatch(t -> !t.isActive());
         }
 
         @Override
@@ -152,9 +154,7 @@ public class DefaultStateUpdater implements StateUpdater {
 
         private void handleRuntimeException(final RuntimeException 
runtimeException) {
             log.error("An unexpected error occurred within the state updater 
thread: " + runtimeException);
-            final ExceptionAndTasks exceptionAndTasks = new 
ExceptionAndTasks(new HashSet<>(updatingTasks.values()), runtimeException);
-            updatingTasks.clear();
-            exceptionsAndFailedTasks.add(exceptionAndTasks);
+            addToExceptionsAndFailedTasksThenClearUpdatingTasks(new 
ExceptionAndTasks(new HashSet<>(updatingTasks.values()), runtimeException));
             isRunning.set(false);
         }
 
@@ -163,41 +163,51 @@ public class DefaultStateUpdater implements StateUpdater {
             final Set<TaskId> corruptedTaskIds = 
taskCorruptedException.corruptedTasks();
             final Set<Task> corruptedTasks = new HashSet<>();
             for (final TaskId taskId : corruptedTaskIds) {
-                final Task corruptedTask = updatingTasks.remove(taskId);
+                final Task corruptedTask = updatingTasks.get(taskId);
                 if (corruptedTask == null) {
                     throw new IllegalStateException("Task " + taskId + " is 
corrupted but is not updating. " + BUG_ERROR_MESSAGE);
                 }
                 corruptedTasks.add(corruptedTask);
             }
-            exceptionsAndFailedTasks.add(new ExceptionAndTasks(corruptedTasks, 
taskCorruptedException));
+            addToExceptionsAndFailedTasksThenRemoveFromUpdatingTasks(new 
ExceptionAndTasks(corruptedTasks, taskCorruptedException));
         }
 
         private void handleStreamsException(final StreamsException 
streamsException) {
             log.info("Encountered streams exception: ", streamsException);
-            final ExceptionAndTasks exceptionAndTasks;
             if (streamsException.taskId().isPresent()) {
-                exceptionAndTasks = 
handleStreamsExceptionWithTask(streamsException);
+                handleStreamsExceptionWithTask(streamsException);
             } else {
-                exceptionAndTasks = 
handleStreamsExceptionWithoutTask(streamsException);
+                handleStreamsExceptionWithoutTask(streamsException);
             }
-            exceptionsAndFailedTasks.add(exceptionAndTasks);
         }
 
-        private ExceptionAndTasks handleStreamsExceptionWithTask(final 
StreamsException streamsException) {
+        private void handleStreamsExceptionWithTask(final StreamsException 
streamsException) {
             final TaskId failedTaskId = streamsException.taskId().get();
             if (!updatingTasks.containsKey(failedTaskId)) {
                 throw new IllegalStateException("Task " + failedTaskId + " 
failed but is not updating. " + BUG_ERROR_MESSAGE);
             }
             final Set<Task> failedTask = new HashSet<>();
             failedTask.add(updatingTasks.get(failedTaskId));
-            updatingTasks.remove(failedTaskId);
-            return new ExceptionAndTasks(failedTask, streamsException);
+            addToExceptionsAndFailedTasksThenRemoveFromUpdatingTasks(new 
ExceptionAndTasks(failedTask, streamsException));
+        }
+
+        private void handleStreamsExceptionWithoutTask(final StreamsException 
streamsException) {
+            addToExceptionsAndFailedTasksThenClearUpdatingTasks(
+                new ExceptionAndTasks(new HashSet<>(updatingTasks.values()), 
streamsException));
+        }
+
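+        // It is important to remove the failed tasks from the updating tasks only after they were added to the
+        // failed tasks. This ordering is intended to ensure that DefaultStateUpdater#getTasks() finds every task:
+        // a task may briefly appear in both collections (getTasks() collects into a set), but not in neither.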
+        private void 
addToExceptionsAndFailedTasksThenRemoveFromUpdatingTasks(final 
ExceptionAndTasks exceptionAndTasks) {
+            exceptionsAndFailedTasks.add(exceptionAndTasks);
+            
exceptionAndTasks.getTasks().stream().map(Task::id).forEach(updatingTasks::remove);
+            transitToUpdateStandbysIfOnlyStandbysLeft();
         }
 
-        private ExceptionAndTasks handleStreamsExceptionWithoutTask(final 
StreamsException streamsException) {
-            final ExceptionAndTasks exceptionAndTasks = new 
ExceptionAndTasks(new HashSet<>(updatingTasks.values()), streamsException);
+        private void addToExceptionsAndFailedTasksThenClearUpdatingTasks(final 
ExceptionAndTasks exceptionAndTasks) {
+            exceptionsAndFailedTasks.add(exceptionAndTasks);
             updatingTasks.clear();
-            return exceptionAndTasks;
         }
 
         private void waitIfAllChangelogsCompletelyRead() throws 
InterruptedException {
@@ -235,7 +245,7 @@ public class DefaultStateUpdater implements StateUpdater {
 
         private void addTask(final Task task) {
             if (isStateless(task)) {
-                addTaskToRestoredTasks((StreamTask) task);
+                addToRestoredTasks((StreamTask) task);
                 log.debug("Stateless active task " + task.id() + " was added 
to the restored tasks of the state updater");
             } else {
                 updatingTasks.put(task.id(), task);
@@ -252,13 +262,15 @@ public class DefaultStateUpdater implements StateUpdater {
         }
 
         private void removeTask(final TaskId taskId) {
-            final Task task = updatingTasks.remove(taskId);
+            final Task task = updatingTasks.get(taskId);
             if (task != null) {
                 task.maybeCheckpoint(true);
 
                 final Collection<TopicPartition> changelogPartitions = 
task.changelogPartitions();
                 changelogReader.unregister(changelogPartitions);
                 removedTasks.add(task);
+                updatingTasks.remove(taskId);
+                transitToUpdateStandbysIfOnlyStandbysLeft();
                 log.debug((task.isActive() ? "Active" : "Standby")
                     + " task " + task.id() + " was removed from the updating 
tasks and added to the removed tasks.");
             } else {
@@ -276,16 +288,20 @@ public class DefaultStateUpdater implements StateUpdater {
             if (restoredChangelogs.containsAll(taskChangelogPartitions)) {
                 task.completeRestoration(offsetResetter);
                 task.maybeCheckpoint(true);
-                addTaskToRestoredTasks(task);
+                addToRestoredTasks(task);
                 updatingTasks.remove(task.id());
                 log.debug("Stateful active task " + task.id() + " completed 
restoration");
-                if (onlyStandbyTasksLeft()) {
-                    changelogReader.transitToUpdateStandby();
-                }
+                transitToUpdateStandbysIfOnlyStandbysLeft();
+            }
+        }
+
+        private void transitToUpdateStandbysIfOnlyStandbysLeft() {
+            if (onlyStandbyTasksLeft()) {
+                changelogReader.transitToUpdateStandby();
             }
         }
 
-        private void addTaskToRestoredTasks(final StreamTask task) {
+        private void addToRestoredTasks(final StreamTask task) {
             restoredActiveTasksLock.lock();
             try {
                 restoredActiveTasks.add(task);
@@ -325,12 +341,12 @@ public class DefaultStateUpdater implements StateUpdater {
     private final Condition restoredActiveTasksCondition = 
restoredActiveTasksLock.newCondition();
     private final BlockingQueue<ExceptionAndTasks> exceptionsAndFailedTasks = 
new LinkedBlockingQueue<>();
     private final BlockingQueue<Task> removedTasks = new 
LinkedBlockingQueue<>();
-    private CountDownLatch shutdownGate;
 
     private final long commitIntervalMs;
     private long lastCommitMs;
 
     private StateUpdaterThread stateUpdaterThread = null;
+    private CountDownLatch shutdownGate;
 
     public DefaultStateUpdater(final StreamsConfig config,
                                final ChangelogReader changelogReader,
@@ -339,20 +355,36 @@ public class DefaultStateUpdater implements StateUpdater {
         this.changelogReader = changelogReader;
         this.offsetResetter = offsetResetter;
         this.time = time;
-
         this.commitIntervalMs = 
config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG);
         // initialize the last commit as of now to prevent first commit 
happens immediately
         this.lastCommitMs = time.milliseconds();
     }
 
-    @Override
-    public void add(final Task task) {
+    public void start() {
         if (stateUpdaterThread == null) {
             stateUpdaterThread = new StateUpdaterThread("state-updater", 
changelogReader, offsetResetter);
             stateUpdaterThread.start();
             shutdownGate = new CountDownLatch(1);
         }
+    }
 
+    @Override
+    public void shutdown(final Duration timeout) {
+        if (stateUpdaterThread != null) {
+            stateUpdaterThread.isRunning.set(false);
+            stateUpdaterThread.interrupt();
+            try {
+                if (!shutdownGate.await(timeout.toMillis(), 
TimeUnit.MILLISECONDS)) {
+                    throw new StreamsException("State updater thread did not 
shutdown within the timeout");
+                }
+                stateUpdaterThread = null;
+            } catch (final InterruptedException ignored) {
+            }
+        }
+    }
+
+    @Override
+    public void add(final Task task) {
         verifyStateFor(task);
 
         tasksAndActionsLock.lock();
@@ -407,8 +439,7 @@ public class DefaultStateUpdater implements StateUpdater {
                 now = time.milliseconds();
             }
             return result;
-        } catch (final InterruptedException e) {
-            // ignore
+        } catch (final InterruptedException ignored) {
         }
         return result;
     }
@@ -428,11 +459,15 @@ public class DefaultStateUpdater implements StateUpdater {
     }
 
     public Set<StandbyTask> getUpdatingStandbyTasks() {
-        return Collections.unmodifiableSet(new 
HashSet<>(stateUpdaterThread.getUpdatingStandbyTasks()));
+        return stateUpdaterThread != null
+            ? Collections.unmodifiableSet(new 
HashSet<>(stateUpdaterThread.getUpdatingStandbyTasks()))
+            : Collections.emptySet();
     }
 
     public Set<Task> getUpdatingTasks() {
-        return Collections.unmodifiableSet(new 
HashSet<>(stateUpdaterThread.getUpdatingTasks()));
+        return stateUpdaterThread != null
+            ? Collections.unmodifiableSet(new 
HashSet<>(stateUpdaterThread.getUpdatingTasks()))
+            : Collections.emptySet();
     }
 
     public Set<StreamTask> getRestoredActiveTasks() {
@@ -453,17 +488,47 @@ public class DefaultStateUpdater implements StateUpdater {
     }
 
     @Override
-    public void shutdown(final Duration timeout) {
-        if (stateUpdaterThread != null) {
-            stateUpdaterThread.isRunning.set(false);
-            stateUpdaterThread.interrupt();
-            try {
-                if (!shutdownGate.await(timeout.toMillis(), 
TimeUnit.MILLISECONDS)) {
-                    throw new StreamsException("State updater thread did not 
shutdown within the timeout");
-                }
-                stateUpdaterThread = null;
-            } catch (final InterruptedException ignored) {
-            }
+    public Set<Task> getTasks() {
+        return executeWithQueuesLocked(() -> 
getStreamOfTasks().collect(Collectors.toSet()));
+    }
+
+    @Override
+    public Set<StreamTask> getActiveTasks() {
+        return executeWithQueuesLocked(
+            () -> getStreamOfTasks().filter(Task::isActive).map(t -> 
(StreamTask) t).collect(Collectors.toSet())
+        );
+    }
+
+    @Override
+    public Set<StandbyTask> getStandbyTasks() {
+        return executeWithQueuesLocked(
+            () -> getStreamOfTasks().filter(t -> !t.isActive()).map(t -> 
(StandbyTask) t).collect(Collectors.toSet())
+        );
+    }
+
+    private <T> Set<T> executeWithQueuesLocked(final Supplier<Set<T>> action) {
+        tasksAndActionsLock.lock();
+        restoredActiveTasksLock.lock();
+        try {
+            return action.get();
+        } finally {
+            restoredActiveTasksLock.unlock();
+            tasksAndActionsLock.unlock();
         }
     }
+
+    private Stream<Task> getStreamOfTasks() {
+        return
+            Stream.concat(
+                tasksAndActions.stream()
+                    .filter(taskAndAction -> taskAndAction.getAction() == 
Action.ADD)
+                    .map(TaskAndAction::getTask),
+                Stream.concat(
+                    getUpdatingTasks().stream(),
+                    Stream.concat(
+                        restoredActiveTasks.stream(),
+                        Stream.concat(
+                            
exceptionsAndFailedTasks.stream().flatMap(exceptionAndTasks -> 
exceptionAndTasks.getTasks().stream()),
+                            removedTasks.stream()))));
+    }
 }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
index 42e65d4adb..1b229bc818 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
@@ -35,7 +35,7 @@ public interface StateUpdater {
             this.exception = Objects.requireNonNull(exception);
         }
 
-        public Set<Task> tasks() {
+        public Set<Task> getTasks() {
             return Collections.unmodifiableSet(tasks);
         }
 
@@ -57,6 +57,21 @@ public interface StateUpdater {
         }
     }
 
+    /**
+     * Starts the state updater.
+     */
+    void start();
+
+    /**
+     * Shuts down the state updater.
+     *
+     * @param timeout duration how long to wait until the state updater is 
shut down
+     *
+     * @throws
+     *     org.apache.kafka.streams.errors.StreamsException if the state 
updater thread cannot shut down within the timeout
+     */
+    void shutdown(final Duration timeout);
+
     /**
      * Adds a task (active or standby) to the state updater.
      *
@@ -113,12 +128,46 @@ public interface StateUpdater {
     List<ExceptionAndTasks> drainExceptionsAndFailedTasks();
 
     /**
-     * Shuts down the state updater.
+     * Gets all tasks that are managed by the state updater.
      *
-     * @param timeout duration how long to wait until the state updater is 
shut down
+     * The state updater manages all tasks that were added with the {@link 
StateUpdater#add(Task)} and that have
+     * not been removed from the state updater with one of the following 
methods:
+     * <ul>
+     *   <li>{@link StateUpdater#drainRestoredActiveTasks(Duration)}</li>
+     *   <li>{@link StateUpdater#drainRemovedTasks()}</li>
+     *   <li>{@link StateUpdater#drainExceptionsAndFailedTasks()}</li>
+     * </ul>
      *
-     * @throws
-     *     org.apache.kafka.streams.errors.StreamsException if the state 
updater thread cannot shutdown within the timeout
+     * @return set of all tasks managed by the state updater
      */
-    void shutdown(final Duration timeout);
+    Set<Task> getTasks();
+
+    /**
+     * Gets active tasks that are managed by the state updater.
+     *
+     * The state updater manages all active tasks that were added with the 
{@link StateUpdater#add(Task)} and that have
+     * not been removed from the state updater with one of the following 
methods:
+     * <ul>
+     *   <li>{@link StateUpdater#drainRestoredActiveTasks(Duration)}</li>
+     *   <li>{@link StateUpdater#drainRemovedTasks()}</li>
+     *   <li>{@link StateUpdater#drainExceptionsAndFailedTasks()}</li>
+     * </ul>
+     *
+     * @return set of all active tasks managed by the state updater
+     */
+    Set<StreamTask> getActiveTasks();
+
+    /**
+     * Gets standby tasks that are managed by the state updater.
+     *
+     * The state updater manages all standby tasks that were added with the 
{@link StateUpdater#add(Task)} and that have
+     * not been removed from the state updater with one of the following 
methods:
+     * <ul>
+     *   <li>{@link StateUpdater#drainRemovedTasks()}</li>
+     *   <li>{@link StateUpdater#drainExceptionsAndFailedTasks()}</li>
+     * </ul>
+     *
+     * @return set of all standby tasks managed by the state updater
+     */
+    Set<StandbyTask> getStandbyTasks();
 }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
index 8f0fc935a8..8bd81828f6 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
@@ -69,10 +69,12 @@ class DefaultStateUpdaterTest {
     private final static long CALL_TIMEOUT = 1000;
     private final static long VERIFICATION_TIMEOUT = 15000;
     private final static TopicPartition TOPIC_PARTITION_A_0 = new 
TopicPartition("topicA", 0);
+    private final static TopicPartition TOPIC_PARTITION_A_1 = new 
TopicPartition("topicA", 1);
     private final static TopicPartition TOPIC_PARTITION_B_0 = new 
TopicPartition("topicB", 0);
     private final static TopicPartition TOPIC_PARTITION_C_0 = new 
TopicPartition("topicC", 0);
     private final static TopicPartition TOPIC_PARTITION_D_0 = new 
TopicPartition("topicD", 0);
     private final static TaskId TASK_0_0 = new TaskId(0, 0);
+    private final static TaskId TASK_0_1 = new TaskId(0, 1);
     private final static TaskId TASK_0_2 = new TaskId(0, 2);
     private final static TaskId TASK_1_0 = new TaskId(1, 0);
     private final static TaskId TASK_1_1 = new TaskId(1, 1);
@@ -100,8 +102,7 @@ class DefaultStateUpdaterTest {
 
     @Test
     public void shouldShutdownStateUpdater() {
-        final StreamTask task = createStatelessTaskInStateRestoring(TASK_0_0);
-        stateUpdater.add(task);
+        stateUpdater.start();
 
         stateUpdater.shutdown(Duration.ofMinutes(1));
 
@@ -110,13 +111,11 @@ class DefaultStateUpdaterTest {
 
     @Test
     public void shouldShutdownStateUpdaterAndRestart() {
-        final StreamTask task1 = createStatelessTaskInStateRestoring(TASK_0_0);
-        stateUpdater.add(task1);
+        stateUpdater.start();
 
         stateUpdater.shutdown(Duration.ofMinutes(1));
 
-        final StreamTask task2 = createStatelessTaskInStateRestoring(TASK_1_0);
-        stateUpdater.add(task2);
+        stateUpdater.start();
 
         stateUpdater.shutdown(Duration.ofMinutes(1));
 
@@ -167,6 +166,7 @@ class DefaultStateUpdaterTest {
     }
 
     private void shouldImmediatelyAddStatelessTasksToRestoredTasks(final 
StreamTask... tasks) throws Exception {
+        stateUpdater.start();
         for (final StreamTask task : tasks) {
             stateUpdater.add(task);
         }
@@ -190,6 +190,7 @@ class DefaultStateUpdaterTest {
             .thenReturn(false)
             .thenReturn(false)
             .thenReturn(true);
+        stateUpdater.start();
 
         stateUpdater.add(task);
 
@@ -219,6 +220,7 @@ class DefaultStateUpdaterTest {
             .thenReturn(false)
             .thenReturn(false)
             .thenReturn(true);
+        stateUpdater.start();
 
         stateUpdater.add(task1);
         stateUpdater.add(task2);
@@ -242,6 +244,7 @@ class DefaultStateUpdaterTest {
         
assertTrue(stateUpdater.drainRestoredActiveTasks(Duration.ZERO).isEmpty());
 
         final StreamTask task1 = createStatelessTaskInStateRestoring(TASK_0_0);
+        stateUpdater.start();
         stateUpdater.add(task1);
 
         verifyDrainingRestoredActiveTasks(task1);
@@ -276,6 +279,7 @@ class DefaultStateUpdaterTest {
     private void shouldUpdateStandbyTasks(final StandbyTask... tasks) throws 
Exception {
         
when(changelogReader.completedChangelogs()).thenReturn(Collections.emptySet());
         when(changelogReader.allChangelogsCompleted()).thenReturn(false);
+        stateUpdater.start();
 
         for (final StandbyTask task : tasks) {
             stateUpdater.add(task);
@@ -300,8 +304,8 @@ class DefaultStateUpdaterTest {
             .thenReturn(Collections.emptySet())
             .thenReturn(mkSet(TOPIC_PARTITION_A_0))
             .thenReturn(mkSet(TOPIC_PARTITION_A_0, TOPIC_PARTITION_B_0));
-        when(changelogReader.allChangelogsCompleted())
-            .thenReturn(false);
+        when(changelogReader.allChangelogsCompleted()).thenReturn(false);
+        stateUpdater.start();
 
         stateUpdater.add(task1);
         stateUpdater.add(task2);
@@ -330,8 +334,8 @@ class DefaultStateUpdaterTest {
             .thenReturn(Collections.emptySet())
             .thenReturn(mkSet(TOPIC_PARTITION_A_0))
             .thenReturn(mkSet(TOPIC_PARTITION_B_0));
-        when(changelogReader.allChangelogsCompleted())
-            .thenReturn(false);
+        when(changelogReader.allChangelogsCompleted()).thenReturn(false);
+        stateUpdater.start();
 
         stateUpdater.add(task1);
         stateUpdater.add(task2);
@@ -352,6 +356,56 @@ class DefaultStateUpdaterTest {
         orderVerifier.verify(changelogReader, 
times(1)).transitToUpdateStandby();
     }
 
+    @Test
+    public void shouldUpdateStandbyTaskAfterAllActiveStatefulTasksFailed() 
throws Exception {
+        final StreamTask activeTask1 = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
+        final StreamTask activeTask2 = 
createActiveStatefulTaskInStateRestoring(TASK_0_1, 
Collections.singletonList(TOPIC_PARTITION_B_0));
+        final StandbyTask standbyTask = 
createStandbyTaskInStateRunning(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_C_0));
+        final TaskCorruptedException taskCorruptedException =
+            new TaskCorruptedException(mkSet(activeTask1.id(), 
activeTask2.id()));
+        final Map<TaskId, Task> updatingTasks1 = mkMap(
+            mkEntry(activeTask1.id(), activeTask1),
+            mkEntry(activeTask2.id(), activeTask2),
+            mkEntry(standbyTask.id(), standbyTask)
+        );
+        
doThrow(taskCorruptedException).doNothing().when(changelogReader).restore(updatingTasks1);
+        when(changelogReader.allChangelogsCompleted()).thenReturn(false);
+        stateUpdater.start();
+
+        stateUpdater.add(activeTask1);
+        stateUpdater.add(activeTask2);
+        stateUpdater.add(standbyTask);
+
+        final ExceptionAndTasks expectedExceptionAndTasks =
+            new ExceptionAndTasks(mkSet(activeTask1, activeTask2), 
taskCorruptedException);
+        verifyExceptionsAndFailedTasks(expectedExceptionAndTasks);
+        final InOrder orderVerifier = inOrder(changelogReader);
+        orderVerifier.verify(changelogReader, 
atLeast(1)).enforceRestoreActive();
+        orderVerifier.verify(changelogReader, 
times(1)).transitToUpdateStandby();
+    }
+
+    @Test
+    public void shouldUpdateStandbyTaskAfterAllActiveStatefulTasksRemoved() 
throws Exception {
+        final StreamTask activeTask1 = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
+        final StreamTask activeTask2 = 
createActiveStatefulTaskInStateRestoring(TASK_0_1, 
Collections.singletonList(TOPIC_PARTITION_B_0));
+        final StandbyTask standbyTask = 
createStandbyTaskInStateRunning(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_C_0));
+        
when(changelogReader.completedChangelogs()).thenReturn(Collections.emptySet());
+        when(changelogReader.allChangelogsCompleted()).thenReturn(false);
+        stateUpdater.start();
+        stateUpdater.add(activeTask1);
+        stateUpdater.add(activeTask2);
+        stateUpdater.add(standbyTask);
+        verifyUpdatingTasks(activeTask1, activeTask2, standbyTask);
+
+        stateUpdater.remove(activeTask1.id());
+        stateUpdater.remove(activeTask2.id());
+
+        verifyRemovedTasks(activeTask1, activeTask2);
+        final InOrder orderVerifier = inOrder(changelogReader);
+        orderVerifier.verify(changelogReader, 
atLeast(1)).enforceRestoreActive();
+        orderVerifier.verify(changelogReader, 
times(1)).transitToUpdateStandby();
+    }
+
     @Test
     public void shouldRemoveActiveStatefulTask() throws Exception {
         final StreamTask task = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
@@ -365,10 +419,9 @@ class DefaultStateUpdaterTest {
     }
 
     private void shouldRemoveStatefulTask(final Task task) throws Exception {
-        when(changelogReader.completedChangelogs())
-            .thenReturn(Collections.emptySet());
-        when(changelogReader.allChangelogsCompleted())
-            .thenReturn(false);
+        
when(changelogReader.completedChangelogs()).thenReturn(Collections.emptySet());
+        when(changelogReader.allChangelogsCompleted()).thenReturn(false);
+        stateUpdater.start();
         stateUpdater.add(task);
 
         stateUpdater.remove(task.id());
@@ -395,10 +448,9 @@ class DefaultStateUpdaterTest {
 
     private void shouldNotRemoveTaskFromRestoredActiveTasks(final StreamTask 
task, final Set<TopicPartition> completedChangelogs) throws Exception {
         final StreamTask controlTask = 
createActiveStatefulTaskInStateRestoring(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_B_0));
-        when(changelogReader.completedChangelogs())
-            .thenReturn(completedChangelogs);
-        when(changelogReader.allChangelogsCompleted())
-            .thenReturn(false);
+        
when(changelogReader.completedChangelogs()).thenReturn(Collections.singleton(TOPIC_PARTITION_A_0));
+        when(changelogReader.allChangelogsCompleted()).thenReturn(false);
+        stateUpdater.start();
         stateUpdater.add(task);
         stateUpdater.add(controlTask);
         verifyRestoredActiveTasks(task);
@@ -427,14 +479,17 @@ class DefaultStateUpdaterTest {
     private void shouldNotRemoveTaskFromFailedTasks(final Task task) throws 
Exception {
         final StreamTask controlTask = 
createActiveStatefulTaskInStateRestoring(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_B_0));
         final StreamsException streamsException = new 
StreamsException("Something happened", task.id());
-        when(changelogReader.completedChangelogs())
-            .thenReturn(Collections.emptySet());
-        when(changelogReader.allChangelogsCompleted())
-            .thenReturn(false);
-        doNothing()
-            .doThrow(streamsException)
+        
when(changelogReader.completedChangelogs()).thenReturn(Collections.emptySet());
+        when(changelogReader.allChangelogsCompleted()).thenReturn(false);
+        final Map<TaskId, Task> updatingTasks = mkMap(
+            mkEntry(task.id(), task),
+            mkEntry(controlTask.id(), controlTask)
+        );
+        doThrow(streamsException)
             .doNothing()
-            .when(changelogReader).restore(anyMap());
+            .when(changelogReader).restore(updatingTasks);
+        stateUpdater.start();
+
         stateUpdater.add(task);
         stateUpdater.add(controlTask);
         final ExceptionAndTasks expectedExceptionAndTasks = new 
ExceptionAndTasks(mkSet(task), streamsException);
@@ -452,10 +507,9 @@ class DefaultStateUpdaterTest {
     @Test
     public void shouldDrainRemovedTasks() throws Exception {
         assertTrue(stateUpdater.drainRemovedTasks().isEmpty());
-        when(changelogReader.completedChangelogs())
-            .thenReturn(Collections.emptySet());
-        when(changelogReader.allChangelogsCompleted())
-            .thenReturn(false);
+        
when(changelogReader.completedChangelogs()).thenReturn(Collections.emptySet());
+        when(changelogReader.allChangelogsCompleted()).thenReturn(false);
+        stateUpdater.start();
 
         final StreamTask task1 = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_B_0));
         stateUpdater.add(task1);
@@ -487,6 +541,7 @@ class DefaultStateUpdaterTest {
             mkEntry(task2.id(), task2)
         );
         
doNothing().doThrow(streamsException).when(changelogReader).restore(updatingTasks);
+        stateUpdater.start();
 
         stateUpdater.add(task1);
         stateUpdater.add(task2);
@@ -521,6 +576,7 @@ class DefaultStateUpdaterTest {
         doNothing()
             .doThrow(streamsException2)
             .when(changelogReader).restore(updatingTasksBeforeSecondThrow);
+        stateUpdater.start();
 
         stateUpdater.add(task1);
         stateUpdater.add(task2);
@@ -547,6 +603,7 @@ class DefaultStateUpdaterTest {
             mkEntry(task3.id(), task3)
         );
         
doNothing().doThrow(taskCorruptedException).doNothing().when(changelogReader).restore(updatingTasks);
+        stateUpdater.start();
 
         stateUpdater.add(task1);
         stateUpdater.add(task2);
@@ -569,6 +626,7 @@ class DefaultStateUpdaterTest {
             mkEntry(task2.id(), task2)
         );
         
doThrow(illegalStateException).when(changelogReader).restore(updatingTasks);
+        stateUpdater.start();
 
         stateUpdater.add(task1);
         stateUpdater.add(task2);
@@ -613,6 +671,7 @@ class DefaultStateUpdaterTest {
             mkEntry(task4.id(), task4)
         );
         
doThrow(streamsException4).when(changelogReader).restore(updatingTasks4);
+        stateUpdater.start();
 
         stateUpdater.add(task1);
 
@@ -635,11 +694,9 @@ class DefaultStateUpdaterTest {
         final StreamTask task2 = 
createActiveStatefulTaskInStateRestoring(TASK_0_2, 
Collections.singletonList(TOPIC_PARTITION_B_0));
         final StandbyTask task3 = createStandbyTaskInStateRunning(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_C_0));
         final StandbyTask task4 = createStandbyTaskInStateRunning(TASK_1_1, 
Collections.singletonList(TOPIC_PARTITION_D_0));
-        when(changelogReader.completedChangelogs())
-                .thenReturn(Collections.emptySet());
-        when(changelogReader.allChangelogsCompleted())
-                .thenReturn(false);
-
+        
when(changelogReader.completedChangelogs()).thenReturn(Collections.emptySet());
+        when(changelogReader.allChangelogsCompleted()).thenReturn(false);
+        stateUpdater.start();
         stateUpdater.add(task1);
         stateUpdater.add(task2);
         stateUpdater.add(task3);
@@ -662,11 +719,9 @@ class DefaultStateUpdaterTest {
             final StreamTask task2 = 
createActiveStatefulTaskInStateRestoring(TASK_0_2, 
Collections.singletonList(TOPIC_PARTITION_B_0));
             final StandbyTask task3 = 
createStandbyTaskInStateRunning(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_C_0));
             final StandbyTask task4 = 
createStandbyTaskInStateRunning(TASK_1_1, 
Collections.singletonList(TOPIC_PARTITION_D_0));
-            when(changelogReader.completedChangelogs())
-                    .thenReturn(Collections.emptySet());
-            when(changelogReader.allChangelogsCompleted())
-                    .thenReturn(false);
-
+            
when(changelogReader.completedChangelogs()).thenReturn(Collections.emptySet());
+            when(changelogReader.allChangelogsCompleted()).thenReturn(false);
+            stateUpdater.start();
             stateUpdater.add(task1);
             stateUpdater.add(task2);
             stateUpdater.add(task3);
@@ -690,6 +745,164 @@ class DefaultStateUpdaterTest {
         }
     }
 
+    @Test
+    public void shouldGetTasksFromInputQueue() {
+        stateUpdater.shutdown(Duration.ofMillis(Long.MAX_VALUE));
+
+        final StreamTask activeTask1 = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
+        final StreamTask activeTask2 = 
createActiveStatefulTaskInStateRestoring(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_B_0));
+        final StandbyTask standbyTask1 = 
createStandbyTaskInStateRunning(TASK_0_2, 
Collections.singletonList(TOPIC_PARTITION_C_0));
+        final StandbyTask standbyTask2 = 
createStandbyTaskInStateRunning(TASK_1_1, 
Collections.singletonList(TOPIC_PARTITION_D_0));
+        final StandbyTask standbyTask3 = 
createStandbyTaskInStateRunning(TASK_0_1, 
Collections.singletonList(TOPIC_PARTITION_A_1));
+        stateUpdater.add(activeTask1);
+        stateUpdater.add(standbyTask1);
+        stateUpdater.add(standbyTask2);
+        stateUpdater.remove(TASK_0_0);
+        stateUpdater.add(activeTask2);
+        stateUpdater.add(standbyTask3);
+
+        final Set<Task> tasks = stateUpdater.getTasks();
+
+        assertEquals(5, tasks.size());
+        assertTrue(tasks.containsAll(mkSet(activeTask1, activeTask2, 
standbyTask1, standbyTask2, standbyTask3)));
+
+        final Set<StreamTask> activeTasks = stateUpdater.getActiveTasks();
+
+        assertEquals(2, activeTasks.size());
+        assertTrue(activeTasks.containsAll(mkSet(activeTask1, activeTask2)));
+
+        final Set<StandbyTask> standbyTasks = stateUpdater.getStandbyTasks();
+
+        assertEquals(3, standbyTasks.size());
+        assertTrue(standbyTasks.containsAll(mkSet(standbyTask1, standbyTask2, 
standbyTask3)));
+    }
+
+    @Test
+    public void shouldGetTasksFromUpdatingTasks() throws Exception {
+        final StreamTask activeTask1 = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
+        final StreamTask activeTask2 = 
createActiveStatefulTaskInStateRestoring(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_B_0));
+        final StandbyTask standbyTask1 = 
createStandbyTaskInStateRunning(TASK_0_2, 
Collections.singletonList(TOPIC_PARTITION_C_0));
+        final StandbyTask standbyTask2 = 
createStandbyTaskInStateRunning(TASK_1_1, 
Collections.singletonList(TOPIC_PARTITION_D_0));
+        final StandbyTask standbyTask3 = 
createStandbyTaskInStateRunning(TASK_0_1, 
Collections.singletonList(TOPIC_PARTITION_A_1));
+        
when(changelogReader.completedChangelogs()).thenReturn(Collections.emptySet());
+        when(changelogReader.allChangelogsCompleted()).thenReturn(false);
+        stateUpdater.start();
+        stateUpdater.add(activeTask1);
+        stateUpdater.add(standbyTask1);
+        stateUpdater.add(standbyTask2);
+        stateUpdater.add(activeTask2);
+        stateUpdater.add(standbyTask3);
+        verifyUpdatingTasks(activeTask1, activeTask2, standbyTask1, 
standbyTask2, standbyTask3);
+
+        final Set<Task> tasks = stateUpdater.getTasks();
+
+        assertEquals(5, tasks.size());
+        assertTrue(tasks.containsAll(mkSet(activeTask1, activeTask2, 
standbyTask1, standbyTask2, standbyTask3)));
+
+        final Set<StreamTask> activeTasks = stateUpdater.getActiveTasks();
+
+        assertEquals(2, activeTasks.size());
+        assertTrue(activeTasks.containsAll(mkSet(activeTask1, activeTask2)));
+
+        final Set<StandbyTask> standbyTasks = stateUpdater.getStandbyTasks();
+
+        assertEquals(3, standbyTasks.size());
+        assertTrue(standbyTasks.containsAll(mkSet(standbyTask1, standbyTask2, 
standbyTask3)));
+    }
+
+    @Test
+    public void shouldGetTasksFromRestoredActiveTasks() throws Exception {
+        final StreamTask activeTask1 = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
+        final StreamTask activeTask2 = 
createActiveStatefulTaskInStateRestoring(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_B_0));
+        
when(changelogReader.completedChangelogs()).thenReturn(mkSet(TOPIC_PARTITION_A_0,
 TOPIC_PARTITION_B_0));
+        when(changelogReader.allChangelogsCompleted()).thenReturn(false);
+        stateUpdater.start();
+        stateUpdater.add(activeTask1);
+        stateUpdater.add(activeTask2);
+        verifyRestoredActiveTasks(activeTask1, activeTask2);
+
+        verifyGetTasks(mkSet(activeTask1, activeTask2), mkSet());
+
+        stateUpdater.drainRestoredActiveTasks(Duration.ofMinutes(1));
+
+        verifyGetTasks(mkSet(), mkSet());
+    }
+
+    @Test
+    public void shouldGetTasksFromExceptionsAndFailedTasks() throws Exception {
+        final StreamTask activeTask1 = 
createActiveStatefulTaskInStateRestoring(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_B_0));
+        final StandbyTask standbyTask2 = 
createStandbyTaskInStateRunning(TASK_1_1, 
Collections.singletonList(TOPIC_PARTITION_D_0));
+        final StandbyTask standbyTask1 = 
createStandbyTaskInStateRunning(TASK_0_1, 
Collections.singletonList(TOPIC_PARTITION_A_1));
+        final TaskCorruptedException taskCorruptedException =
+            new TaskCorruptedException(mkSet(standbyTask1.id(), 
standbyTask2.id()));
+        final StreamsException streamsException = new StreamsException("The 
Streams were crossed!", activeTask1.id());
+        final Map<TaskId, Task> updatingTasks1 = mkMap(
+            mkEntry(activeTask1.id(), activeTask1),
+            mkEntry(standbyTask1.id(), standbyTask1),
+            mkEntry(standbyTask2.id(), standbyTask2)
+        );
+        
doNothing().doThrow(taskCorruptedException).doNothing().when(changelogReader).restore(updatingTasks1);
+        final Map<TaskId, Task> updatingTasks2 = mkMap(
+            mkEntry(activeTask1.id(), activeTask1)
+        );
+        
doNothing().doThrow(streamsException).doNothing().when(changelogReader).restore(updatingTasks2);
+        stateUpdater.start();
+        stateUpdater.add(standbyTask1);
+        stateUpdater.add(activeTask1);
+        stateUpdater.add(standbyTask2);
+        final ExceptionAndTasks expectedExceptionAndTasks1 =
+            new ExceptionAndTasks(mkSet(standbyTask1, standbyTask2), 
taskCorruptedException);
+        final ExceptionAndTasks expectedExceptionAndTasks2 = new 
ExceptionAndTasks(mkSet(activeTask1), streamsException);
+        verifyExceptionsAndFailedTasks(expectedExceptionAndTasks1, 
expectedExceptionAndTasks2);
+
+        verifyGetTasks(mkSet(activeTask1), mkSet(standbyTask1, standbyTask2));
+
+        stateUpdater.drainExceptionsAndFailedTasks();
+
+        verifyGetTasks(mkSet(), mkSet());
+    }
+
+    @Test
+    public void shouldGetTasksFromRemovedTasks() throws Exception {
+        final StreamTask activeTask = 
createActiveStatefulTaskInStateRestoring(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_B_0));
+        final StandbyTask standbyTask2 = 
createStandbyTaskInStateRunning(TASK_1_1, 
Collections.singletonList(TOPIC_PARTITION_D_0));
+        final StandbyTask standbyTask1 = 
createStandbyTaskInStateRunning(TASK_0_1, 
Collections.singletonList(TOPIC_PARTITION_A_1));
+        
when(changelogReader.completedChangelogs()).thenReturn(Collections.emptySet());
+        when(changelogReader.allChangelogsCompleted()).thenReturn(false);
+        stateUpdater.start();
+        stateUpdater.add(standbyTask1);
+        stateUpdater.add(activeTask);
+        stateUpdater.add(standbyTask2);
+        stateUpdater.remove(standbyTask1.id());
+        stateUpdater.remove(standbyTask2.id());
+        stateUpdater.remove(activeTask.id());
+        verifyRemovedTasks(activeTask, standbyTask1, standbyTask2);
+
+        verifyGetTasks(mkSet(activeTask), mkSet(standbyTask1, standbyTask2));
+
+        stateUpdater.drainRemovedTasks();
+
+        verifyGetTasks(mkSet(), mkSet());
+    }
+
+    private void verifyGetTasks(final Set<StreamTask> expectedActiveTasks,
+                                final Set<StandbyTask> expectedStandbyTasks) {
+        final Set<Task> tasks = stateUpdater.getTasks();
+
+        final Set<Task> expectedTasks = new HashSet<>(expectedActiveTasks);
+        expectedTasks.addAll(expectedStandbyTasks);
+        assertEquals(expectedActiveTasks.size() + expectedStandbyTasks.size(), 
tasks.size());
+        assertTrue(tasks.containsAll(expectedTasks));
+
+        final Set<StreamTask> activeTasks = stateUpdater.getActiveTasks();
+        assertEquals(expectedActiveTasks.size(), activeTasks.size());
+        assertTrue(activeTasks.containsAll(expectedActiveTasks));
+
+        final Set<StandbyTask> standbyTasks = stateUpdater.getStandbyTasks();
+        assertEquals(expectedStandbyTasks.size(), standbyTasks.size());
+        assertTrue(standbyTasks.containsAll(expectedStandbyTasks));
+    }
+
     private void verifyRestoredActiveTasks(final StreamTask... tasks) throws 
Exception {
         if (tasks.length == 0) {
             assertTrue(stateUpdater.getRestoredActiveTasks().isEmpty());
@@ -699,12 +912,12 @@ class DefaultStateUpdaterTest {
             waitForCondition(
                 () -> {
                     
restoredTasks.addAll(stateUpdater.getRestoredActiveTasks());
-                    return restoredTasks.containsAll(expectedRestoredTasks);
+                    return restoredTasks.containsAll(expectedRestoredTasks)
+                        && restoredTasks.size() == 
expectedRestoredTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
                 "Did not get all restored active task within the given 
timeout!"
             );
-            assertEquals(expectedRestoredTasks.size(), restoredTasks.size());
             assertTrue(restoredTasks.stream().allMatch(task -> task.state() == 
State.RESTORING));
         }
     }
@@ -715,12 +928,12 @@ class DefaultStateUpdaterTest {
         waitForCondition(
             () -> {
                 
restoredTasks.addAll(stateUpdater.drainRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT)));
-                return restoredTasks.containsAll(expectedRestoredTasks);
+                return restoredTasks.containsAll(expectedRestoredTasks)
+                    && restoredTasks.size() == expectedRestoredTasks.size();
             },
             VERIFICATION_TIMEOUT,
             "Did not get all restored active task within the given timeout!"
         );
-        assertEquals(expectedRestoredTasks.size(), restoredTasks.size());
         
assertTrue(stateUpdater.drainRestoredActiveTasks(Duration.ZERO).isEmpty());
     }
 
@@ -733,13 +946,16 @@ class DefaultStateUpdaterTest {
             waitForCondition(
                 () -> {
                     updatingTasks.addAll(stateUpdater.getUpdatingTasks());
-                    return updatingTasks.containsAll(expectedUpdatingTasks);
+                    return updatingTasks.containsAll(expectedUpdatingTasks)
+                        && updatingTasks.size() == 
expectedUpdatingTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
                 "Did not get all updating task within the given timeout!"
             );
-            assertEquals(expectedUpdatingTasks.size(), updatingTasks.size());
-            assertTrue(updatingTasks.stream().allMatch(task -> task.state() == 
State.RESTORING));
+            assertTrue(updatingTasks.stream()
+                .allMatch(task -> task.isActive() && task.state() == 
State.RESTORING
+                    ||
+                    !task.isActive() && task.state() == State.RUNNING));
         }
     }
 
@@ -749,12 +965,12 @@ class DefaultStateUpdaterTest {
         waitForCondition(
             () -> {
                 standbyTasks.addAll(stateUpdater.getUpdatingStandbyTasks());
-                return standbyTasks.containsAll(expectedStandbyTasks);
+                return standbyTasks.containsAll(expectedStandbyTasks)
+                    && standbyTasks.size() == expectedStandbyTasks.size();
             },
             VERIFICATION_TIMEOUT,
             "Did not see all standby task within the given timeout!"
         );
-        assertEquals(expectedStandbyTasks.size(), standbyTasks.size());
         assertTrue(standbyTasks.stream().allMatch(task -> task.state() == 
State.RUNNING));
     }
 
@@ -767,15 +983,15 @@ class DefaultStateUpdaterTest {
             waitForCondition(
                 () -> {
                     removedTasks.addAll(stateUpdater.getRemovedTasks());
-                    return removedTasks.containsAll(mkSet(tasks));
+                    return removedTasks.containsAll(expectedRemovedTasks)
+                        && removedTasks.size() == expectedRemovedTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
                 "Did not get all removed task within the given timeout!"
             );
-            assertEquals(expectedRemovedTasks.size(), removedTasks.size());
             assertTrue(removedTasks.stream()
                 .allMatch(task -> task.isActive() && task.state() == 
State.RESTORING
-                        || !task.isActive() && task.state() == State.RUNNING));
+                    || !task.isActive() && task.state() == State.RUNNING));
         }
     }
 
@@ -785,27 +1001,27 @@ class DefaultStateUpdaterTest {
         waitForCondition(
             () -> {
                 removedTasks.addAll(stateUpdater.drainRemovedTasks());
-                return removedTasks.containsAll(mkSet(tasks));
+                return removedTasks.containsAll(mkSet(tasks))
+                    && removedTasks.size() == expectedRemovedTasks.size();
             },
             VERIFICATION_TIMEOUT,
             "Did not get all restored active task within the given timeout!"
         );
-        assertEquals(expectedRemovedTasks.size(), removedTasks.size());
         assertTrue(stateUpdater.drainRemovedTasks().isEmpty());
     }
 
     private void verifyExceptionsAndFailedTasks(final ExceptionAndTasks... 
exceptionsAndTasks) throws Exception {
         final List<ExceptionAndTasks> expectedExceptionAndTasks = 
Arrays.asList(exceptionsAndTasks);
-        final List<ExceptionAndTasks> failedTasks = new ArrayList<>();
+        final Set<ExceptionAndTasks> failedTasks = new HashSet<>();
         waitForCondition(
             () -> {
                 failedTasks.addAll(stateUpdater.getExceptionsAndFailedTasks());
-                return failedTasks.containsAll(expectedExceptionAndTasks);
+                return failedTasks.containsAll(expectedExceptionAndTasks)
+                    && failedTasks.size() == expectedExceptionAndTasks.size();
             },
             VERIFICATION_TIMEOUT,
             "Did not get all exceptions and failed tasks within the given 
timeout!"
         );
-        assertEquals(expectedExceptionAndTasks.size(), failedTasks.size());
     }
 
     private void verifyDrainingExceptionsAndFailedTasks(final 
ExceptionAndTasks... exceptionsAndTasks) throws Exception {
@@ -814,12 +1030,12 @@ class DefaultStateUpdaterTest {
         waitForCondition(
             () -> {
                 
failedTasks.addAll(stateUpdater.drainExceptionsAndFailedTasks());
-                return failedTasks.containsAll(expectedExceptionAndTasks);
+                return failedTasks.containsAll(expectedExceptionAndTasks)
+                    && failedTasks.size() == expectedExceptionAndTasks.size();
             },
             VERIFICATION_TIMEOUT,
             "Did not get all exceptions and failed tasks within the given 
timeout!"
         );
-        assertEquals(expectedExceptionAndTasks.size(), failedTasks.size());
         assertTrue(stateUpdater.drainExceptionsAndFailedTasks().isEmpty());
     }
 

Reply via email to