This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new ced5989ff6 KAFKA-10199: Implement adding active tasks to the state 
updater (#12128)
ced5989ff6 is described below

commit ced5989ff69f8a5e76518fdeb39f41ab20b2574f
Author: Bruno Cadonna <cado...@apache.org>
AuthorDate: Fri May 6 01:00:35 2022 +0200

    KAFKA-10199: Implement adding active tasks to the state updater (#12128)
    
    This PR adds the default implementation of the state updater. The 
implementation only implements adding active tasks to the state updater.
    
    Reviewers: Guozhang Wang <wangg...@gmail.com>
---
 .../processor/internals/ChangelogReader.java       |   2 +
 .../processor/internals/DefaultStateUpdater.java   | 373 +++++++++++++++++++
 .../streams/processor/internals/StateUpdater.java  |  19 +-
 .../processor/internals/StoreChangelogReader.java  |   3 +-
 .../internals/DefaultStateUpdaterTest.java         | 408 +++++++++++++++++++++
 .../processor/internals/MockChangelogReader.java   |   5 +
 6 files changed, 804 insertions(+), 6 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java
index 9c62dd182e..38b00232c8 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java
@@ -46,6 +46,8 @@ public interface ChangelogReader extends ChangelogRegister {
      */
     Set<TopicPartition> completedChangelogs();
 
+    /**
+     * Returns whether all changelogs handled by this reader are completed.
+     *
+     * <p>The state updater uses this to decide whether it may block until new tasks are added.
+     *
+     * <p>NOTE(review): StoreChangelogReader implements this as "every tracked changelog is in state
+     * COMPLETED", which is also {@code true} when no changelog is tracked — confirm that all
+     * implementations agree on this.
+     */
+    boolean allChangelogsCompleted();
+
     /**
      * Clear all partitions
      */
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
new file mode 100644
index 0000000000..0b6558d8ac
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
@@ -0,0 +1,373 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.Task.State;
+import org.slf4j.Logger;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.Set;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+import java.util.stream.Collectors;
+
+/**
+ * Default implementation of {@link StateUpdater}.
+ *
+ * <p>Tasks handed to {@link #add(Task)} are queued together with an action and picked up by a single
+ * background state updater thread. That thread is started lazily on the first call to
+ * {@link #add(Task)}. It restores the tasks through a {@link ChangelogReader} and publishes active
+ * tasks whose changelogs are completely read to the restored-active-tasks queue. It records tasks
+ * that fail during restoration in the failed-tasks queue.
+ *
+ * <p>Lock ordering: whenever both locks are held, {@code tasksAndActionsLock} is acquired before
+ * {@code restoredActiveTasksLock}.
+ *
+ * <p>Currently only adding tasks is implemented; {@link #remove(Task)} is a no-op.
+ */
+public class DefaultStateUpdater implements StateUpdater {
+
+    private final static String BUG_ERROR_MESSAGE = "This indicates a bug. " +
+        "Please report at https://issues.apache.org/jira/projects/KAFKA/issues or to the dev-mailing list (https://kafka.apache.org/contact).";
+
+    private class StateUpdaterThread extends Thread {
+
+        private final ChangelogReader changelogReader;
+        private final AtomicBoolean isRunning = new AtomicBoolean(true);
+        private final java.util.function.Consumer<Set<TopicPartition>> offsetResetter;
+        // Only this thread mutates the map, but other threads read it through getAllUpdatingTasks().
+        // The synchronized wrapper lets those readers take a snapshot without racing with mutations.
+        private final Map<TaskId, Task> updatingTasks = Collections.synchronizedMap(new HashMap<>());
+        private final Logger log;
+
+        public StateUpdaterThread(final String name,
+                                  final ChangelogReader changelogReader,
+                                  final java.util.function.Consumer<Set<TopicPartition>> offsetResetter) {
+            super(name);
+            this.changelogReader = changelogReader;
+            this.offsetResetter = offsetResetter;
+
+            final String logPrefix = String.format("%s ", name);
+            final LogContext logContext = new LogContext(logPrefix);
+            log = logContext.logger(DefaultStateUpdater.class);
+        }
+
+        /**
+         * Returns a snapshot of the tasks that are currently being restored.
+         * Safe to call from threads other than the state updater thread.
+         */
+        public Collection<Task> getAllUpdatingTasks() {
+            // Iterating a synchronized map requires holding its monitor.
+            synchronized (updatingTasks) {
+                return new ArrayList<>(updatingTasks.values());
+            }
+        }
+
+        @Override
+        public void run() {
+            try {
+                while (isRunning.get()) {
+                    try {
+                        performActionsOnTasks();
+                        restoreTasks();
+                        waitIfAllChangelogsCompletelyRead();
+                    } catch (final InterruptedException interruptedException) {
+                        // shutdown() interrupts this thread to wake it up; leave the loop.
+                        return;
+                    }
+                }
+            } catch (final RuntimeException anyOtherException) {
+                // Pass the exception as the last argument so that its stack trace is logged.
+                log.error("An unexpected error occurred within the state updater thread", anyOtherException);
+                final ExceptionAndTasks exceptionAndTasks = new ExceptionAndTasks(new HashSet<>(updatingTasks.values()), anyOtherException);
+                updatingTasks.clear();
+                failedTasks.add(exceptionAndTasks);
+                isRunning.set(false);
+            } finally {
+                // NOTE(review): this also drops queued actions and restored tasks that were not yet
+                // collected by the caller — intended for shutdown; confirm it is wanted after a failure.
+                clear();
+            }
+        }
+
+        /** Drains the queued task actions and applies them to this thread's state. */
+        private void performActionsOnTasks() {
+            tasksAndActionsLock.lock();
+            try {
+                for (final TaskAndAction taskAndAction : getTasksAndActions()) {
+                    final Task task = taskAndAction.task;
+                    final Action action = taskAndAction.action;
+                    switch (action) {
+                        case ADD:
+                            addTask(task);
+                            break;
+                    }
+                }
+            } finally {
+                tasksAndActionsLock.unlock();
+            }
+        }
+
+        /**
+         * Runs one restoration step and moves active tasks whose changelogs are completely read to the
+         * restored tasks. Failures reported by the changelog reader are recorded as failed tasks.
+         */
+        private void restoreTasks() {
+            try {
+                // ToDo: Prioritize restoration of active tasks over standby tasks
+                //                changelogReader.enforceRestoreActive();
+                changelogReader.restore(updatingTasks);
+            } catch (final TaskCorruptedException taskCorruptedException) {
+                handleTaskCorruptedException(taskCorruptedException);
+            } catch (final StreamsException streamsException) {
+                handleStreamsException(streamsException);
+            }
+            final Set<TopicPartition> completedChangelogs = changelogReader.completedChangelogs();
+            // Copy first: ending a task's restoration removes it from updatingTasks.
+            final List<Task> activeTasks = updatingTasks.values().stream().filter(Task::isActive).collect(Collectors.toList());
+            for (final Task task : activeTasks) {
+                endRestorationIfChangelogsCompletelyRead(task, completedChangelogs);
+            }
+        }
+
+        private void handleTaskCorruptedException(final TaskCorruptedException taskCorruptedException) {
+            final Set<TaskId> corruptedTaskIds = taskCorruptedException.corruptedTasks();
+            final Set<Task> corruptedTasks = new HashSet<>();
+            for (final TaskId taskId : corruptedTaskIds) {
+                final Task corruptedTask = updatingTasks.remove(taskId);
+                if (corruptedTask == null) {
+                    throw new IllegalStateException("Task " + taskId + " is corrupted but is not updating. " + BUG_ERROR_MESSAGE);
+                }
+                corruptedTasks.add(corruptedTask);
+            }
+            failedTasks.add(new ExceptionAndTasks(corruptedTasks, taskCorruptedException));
+        }
+
+        private void handleStreamsException(final StreamsException streamsException) {
+            final ExceptionAndTasks exceptionAndTasks;
+            if (streamsException.taskId().isPresent()) {
+                exceptionAndTasks = handleStreamsExceptionWithTask(streamsException);
+            } else {
+                exceptionAndTasks = handleStreamsExceptionWithoutTask(streamsException);
+            }
+            failedTasks.add(exceptionAndTasks);
+        }
+
+        /** Removes only the task named by the exception from the updating tasks and reports it as failed. */
+        private ExceptionAndTasks handleStreamsExceptionWithTask(final StreamsException streamsException) {
+            final TaskId failedTaskId = streamsException.taskId().get();
+            final Task failedTask = updatingTasks.remove(failedTaskId);
+            if (failedTask == null) {
+                throw new IllegalStateException("Task " + failedTaskId + " failed but is not updating. " + BUG_ERROR_MESSAGE);
+            }
+            final Set<Task> tasksOfException = new HashSet<>();
+            tasksOfException.add(failedTask);
+            return new ExceptionAndTasks(tasksOfException, streamsException);
+        }
+
+        /** Without a task id the failure cannot be attributed, so all updating tasks are reported as failed. */
+        private ExceptionAndTasks handleStreamsExceptionWithoutTask(final StreamsException streamsException) {
+            final ExceptionAndTasks exceptionAndTasks = new ExceptionAndTasks(new HashSet<>(updatingTasks.values()), streamsException);
+            updatingTasks.clear();
+            return exceptionAndTasks;
+        }
+
+        /**
+         * Blocks until a new task action is queued, but only when all changelogs are completed.
+         * Otherwise it returns immediately so that restoration can continue.
+         *
+         * @throws InterruptedException if this thread is interrupted while waiting (e.g. by shutdown())
+         */
+        private void waitIfAllChangelogsCompletelyRead() throws InterruptedException {
+            if (isRunning.get() && changelogReader.allChangelogsCompleted()) {
+                tasksAndActionsLock.lock();
+                try {
+                    while (tasksAndActions.isEmpty()) {
+                        tasksAndActionsCondition.await();
+                    }
+                } finally {
+                    tasksAndActionsLock.unlock();
+                }
+            }
+        }
+
+        /** Resets all shared and thread-local state when the thread terminates. */
+        private void clear() {
+            tasksAndActionsLock.lock();
+            restoredActiveTasksLock.lock();
+            try {
+                tasksAndActions.clear();
+                restoredActiveTasks.clear();
+            } finally {
+                restoredActiveTasksLock.unlock();
+                tasksAndActionsLock.unlock();
+            }
+            changelogReader.clear();
+            updatingTasks.clear();
+        }
+
+        /** Takes all queued task actions. The caller must hold {@code tasksAndActionsLock}. */
+        private List<TaskAndAction> getTasksAndActions() {
+            final List<TaskAndAction> tasksAndActionsToProcess = new ArrayList<>(tasksAndActions);
+            tasksAndActions.clear();
+            return tasksAndActionsToProcess;
+        }
+
+        /** Stateless active tasks need no restoration and are handed back immediately. */
+        private void addTask(final Task task) {
+            if (isStateless(task)) {
+                addTaskToRestoredTasks((StreamTask) task);
+            } else {
+                updatingTasks.put(task.id(), task);
+            }
+        }
+
+        private boolean isStateless(final Task task) {
+            return task.changelogPartitions().isEmpty() && task.isActive();
+        }
+
+        private void endRestorationIfChangelogsCompletelyRead(final Task task,
+                                                              final Set<TopicPartition> restoredChangelogs) {
+            final Collection<TopicPartition> taskChangelogPartitions = task.changelogPartitions();
+            if (restoredChangelogs.containsAll(taskChangelogPartitions)) {
+                task.completeRestoration(offsetResetter);
+                addTaskToRestoredTasks((StreamTask) task);
+                updatingTasks.remove(task.id());
+            }
+        }
+
+        /** Publishes a restored active task and wakes up callers waiting in getRestoredActiveTasks(). */
+        private void addTaskToRestoredTasks(final StreamTask task) {
+            restoredActiveTasksLock.lock();
+            try {
+                restoredActiveTasks.add(task);
+                restoredActiveTasksCondition.signalAll();
+            } finally {
+                restoredActiveTasksLock.unlock();
+            }
+        }
+    }
+
+    enum Action {
+        ADD
+    }
+
+    private static class TaskAndAction {
+        public final Task task;
+        public final Action action;
+
+        public TaskAndAction(final Task task, final Action action) {
+            this.task = task;
+            this.action = action;
+        }
+    }
+
+    private final Time time;
+    private final ChangelogReader changelogReader;
+    private final java.util.function.Consumer<Set<TopicPartition>> offsetResetter;
+    private final Queue<TaskAndAction> tasksAndActions = new LinkedList<>();
+    private final Lock tasksAndActionsLock = new ReentrantLock();
+    private final Condition tasksAndActionsCondition = tasksAndActionsLock.newCondition();
+    private final Queue<StreamTask> restoredActiveTasks = new LinkedList<>();
+    private final Lock restoredActiveTasksLock = new ReentrantLock();
+    private final Condition restoredActiveTasksCondition = restoredActiveTasksLock.newCondition();
+    private final BlockingQueue<ExceptionAndTasks> failedTasks = new LinkedBlockingQueue<>();
+
+    private StateUpdaterThread stateUpdaterThread = null;
+
+    public DefaultStateUpdater(final ChangelogReader changelogReader,
+                               final java.util.function.Consumer<Set<TopicPartition>> offsetResetter,
+                               final Time time) {
+        this.changelogReader = changelogReader;
+        this.offsetResetter = offsetResetter;
+        this.time = time;
+    }
+
+    /**
+     * Queues a task for restoration, starting the state updater thread if it is not running.
+     *
+     * @throws IllegalStateException if an active task is not in state RESTORING
+     */
+    @Override
+    public void add(final Task task) {
+        // Validate first so that a rejected task does not start the state updater thread.
+        verifyStateFor(task);
+
+        if (stateUpdaterThread == null) {
+            stateUpdaterThread = new StateUpdaterThread("state-updater", changelogReader, offsetResetter);
+            stateUpdaterThread.start();
+        }
+
+        tasksAndActionsLock.lock();
+        try {
+            tasksAndActions.add(new TaskAndAction(task, Action.ADD));
+            tasksAndActionsCondition.signalAll();
+        } finally {
+            tasksAndActionsLock.unlock();
+        }
+    }
+
+    private void verifyStateFor(final Task task) {
+        if (task.isActive() && task.state() != State.RESTORING) {
+            throw new IllegalStateException("Active task " + task.id() + " is not in state RESTORING. " + BUG_ERROR_MESSAGE);
+        }
+    }
+
+    @Override
+    public void remove(final Task task) {
+    }
+
+    /**
+     * Waits up to {@code timeout} until at least one restored active task is available.
+     * Returns all restored active tasks that are available at that point.
+     *
+     * @return the restored active tasks; empty if none became available in time or if the calling
+     *         thread was interrupted (the interrupt status is restored in that case)
+     */
+    @Override
+    public Set<StreamTask> getRestoredActiveTasks(final Duration timeout) {
+        final long timeoutMs = timeout.toMillis();
+        final long startMs = time.milliseconds();
+        // Saturate instead of overflowing for very large timeouts.
+        final long deadline = timeoutMs > Long.MAX_VALUE - startMs ? Long.MAX_VALUE : startMs + timeoutMs;
+        final Set<StreamTask> result = new HashSet<>();
+        restoredActiveTasksLock.lock();
+        try {
+            long remainingMs = timeoutMs;
+            while (restoredActiveTasks.isEmpty() && remainingMs > 0) {
+                // await() returns false once the waiting time elapsed without being signalled.
+                if (!restoredActiveTasksCondition.await(remainingMs, TimeUnit.MILLISECONDS)) {
+                    break;
+                }
+                remainingMs = deadline - time.milliseconds();
+            }
+            while (!restoredActiveTasks.isEmpty()) {
+                result.add(restoredActiveTasks.poll());
+            }
+        } catch (final InterruptedException e) {
+            // Do not swallow the interruption; let the caller observe it.
+            Thread.currentThread().interrupt();
+        } finally {
+            restoredActiveTasksLock.unlock();
+        }
+        return result;
+    }
+
+    /** Drains and returns the failures recorded since the previous call. */
+    @Override
+    public List<ExceptionAndTasks> getFailedTasksAndExceptions() {
+        final List<ExceptionAndTasks> result = new ArrayList<>();
+        failedTasks.drainTo(result);
+        return result;
+    }
+
+    /**
+     * Returns the queued, updating, and restored-but-not-yet-collected tasks.
+     * Safe to call before the first {@link #add(Task)} and after {@link #shutdown(Duration)}.
+     */
+    @Override
+    public Set<Task> getAllTasks() {
+        // Read once: shutdown() may reset the field.
+        final StateUpdaterThread updaterThread = stateUpdaterThread;
+        tasksAndActionsLock.lock();
+        restoredActiveTasksLock.lock();
+        try {
+            final Set<Task> allTasks = new HashSet<>();
+            allTasks.addAll(tasksAndActions.stream()
+                .filter(t -> t.action == Action.ADD)
+                .map(t -> t.task)
+                .collect(Collectors.toList())
+            );
+            // The thread is only created by add() and is reset by shutdown().
+            if (updaterThread != null) {
+                allTasks.addAll(updaterThread.getAllUpdatingTasks());
+            }
+            allTasks.addAll(restoredActiveTasks);
+            return Collections.unmodifiableSet(allTasks);
+        } finally {
+            restoredActiveTasksLock.unlock();
+            tasksAndActionsLock.unlock();
+        }
+    }
+
+    /**
+     * Stops the state updater thread and waits up to {@code timeout} for it to terminate.
+     * A zero or negative timeout does not wait.
+     */
+    @Override
+    public void shutdown(final Duration timeout) {
+        if (stateUpdaterThread != null) {
+            stateUpdaterThread.isRunning.set(false);
+            stateUpdaterThread.interrupt();
+            try {
+                final long timeoutMs = timeout.toMillis();
+                // Thread.join(0) waits forever and a negative value throws, so only join for positive timeouts.
+                if (timeoutMs > 0) {
+                    stateUpdaterThread.join(timeoutMs);
+                }
+                stateUpdaterThread = null;
+            } catch (final InterruptedException e) {
+                // Restore the interrupt status so that the caller can react to it.
+                Thread.currentThread().interrupt();
+            }
+        }
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
index 8965abfbe9..9e98e0d2c9 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
@@ -22,6 +22,16 @@ import java.util.Set;
 
 public interface StateUpdater {
 
+    /**
+     * Pairs an exception raised while updating state with the tasks that are affected by it.
+     */
+    class ExceptionAndTasks {
+        public final Set<Task> tasks;
+        public final RuntimeException exception;
+
+        public ExceptionAndTasks(final Set<Task> affectedTasks, final RuntimeException cause) {
+            exception = cause;
+            tasks = affectedTasks;
+        }
+    }
+
     /**
      * Adds a task (active or standby) to the state updater.
      *
@@ -41,17 +51,16 @@ public interface StateUpdater {
      *
      * @param timeout duration how long the calling thread should wait for 
restored active tasks
      *
-     * @return list of active tasks with up-to-date states
+     * @return set of active tasks with up-to-date states
      */
     Set<StreamTask> getRestoredActiveTasks(final Duration timeout);
 
     /**
-     * Gets a list of exceptions thrown during restoration.
+     * Gets failed tasks and the corresponding exceptions.
+     *
+     * <p>NOTE(review): DefaultStateUpdater drains its internal queue here, so each failure is returned
+     * by at most one call — confirm this is the intended contract for all implementations.
      *
-     * @return exceptions
+     * @return list of failed tasks and the corresponding exceptions
      */
-    List<RuntimeException> getExceptions();
-
+    List<ExceptionAndTasks> getFailedTasksAndExceptions();
 
     /**
      * Get all tasks (active and standby) that are managed by the state 
updater.
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
index fdf027f2be..756bf11b0a 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
@@ -394,7 +394,8 @@ public class StoreChangelogReader implements ChangelogReader {
             .collect(Collectors.toSet());
     }
 
-    private boolean allChangelogsCompleted() {
+    /**
+     * Returns whether every changelog tracked in {@code changelogs} is in state
+     * {@code ChangelogState.COMPLETED}. Trivially {@code true} when no changelog is tracked.
+     */
+    @Override
+    public boolean allChangelogsCompleted() {
         return changelogs.values().stream()
             .allMatch(metadata -> metadata.changelogState == ChangelogState.COMPLETED);
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
new file mode 100644
index 0000000000..e94d8b1488
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
@@ -0,0 +1,408 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.StateUpdater.ExceptionAndTasks;
+import org.apache.kafka.streams.processor.internals.Task.State;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.test.TestUtils.waitForCondition;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.anyMap;
+import static org.mockito.Mockito.atLeast;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+class DefaultStateUpdaterTest {
+
+    // Timeout in ms for each individual getRestoredActiveTasks() call made while polling.
+    private final static long CALL_TIMEOUT = 1000;
+    // Overall timeout passed to waitForCondition() when waiting for restored tasks.
+    private final static long VERIFICATION_TIMEOUT = 15000;
+    private final static TopicPartition TOPIC_PARTITION_A_0 = new TopicPartition("topicA", 0);
+    private final static TopicPartition TOPIC_PARTITION_B_0 = new TopicPartition("topicB", 0);
+    private final static TopicPartition TOPIC_PARTITION_C_0 = new TopicPartition("topicC", 0);
+    private final static TaskId TASK_0_0 = new TaskId(0, 0);
+    private final static TaskId TASK_0_2 = new TaskId(0, 2);
+    private final static TaskId TASK_1_0 = new TaskId(1, 0);
+
+    // Mocked so that each test can script completedChangelogs(), allChangelogsCompleted() and restore().
+    private final ChangelogReader changelogReader = mock(ChangelogReader.class);
+    // No-op resetter; tests only verify that it is passed on to completeRestoration().
+    private final java.util.function.Consumer<Set<TopicPartition>> offsetResetter = topicPartitions -> { };
+    // Uses system time, so the timeouts of getRestoredActiveTasks() are real wall-clock waits.
+    private final DefaultStateUpdater stateUpdater = new DefaultStateUpdater(changelogReader, offsetResetter, Time.SYSTEM);
+
+    @AfterEach
+    public void tearDown() {
+        // Stop the state updater thread, if one was started, after every test.
+        final Duration shutdownTimeout = Duration.ofMinutes(1);
+        stateUpdater.shutdown(shutdownTimeout);
+    }
+
+    @Test
+    public void shouldShutdownStateUpdater() {
+        // Adding a task lazily starts the state updater thread.
+        stateUpdater.add(createStatelessTaskInStateRestoring(TASK_0_0));
+
+        stateUpdater.shutdown(Duration.ofMinutes(1));
+
+        // Terminating the thread must clear the changelog reader exactly once.
+        verify(changelogReader, times(1)).clear();
+    }
+
+    @Test
+    public void shouldShutdownStateUpdaterAndRestart() {
+        // Each iteration restarts the state updater by adding a task and then shuts it down again.
+        for (final TaskId taskId : Arrays.asList(TASK_0_0, TASK_1_0)) {
+            stateUpdater.add(createStatelessTaskInStateRestoring(taskId));
+            stateUpdater.shutdown(Duration.ofMinutes(1));
+        }
+
+        verify(changelogReader, times(2)).clear();
+    }
+
+    @Test
+    public void shouldThrowIfStatelessTaskNotInStateRestoring() {
+        final StreamTask statelessTask = createStatelessTask(TASK_0_0);
+        shouldThrowIfTaskNotInStateRestoring(statelessTask);
+    }
+
+    @Test
+    public void shouldThrowIfStatefulTaskNotInStateRestoring() {
+        final StreamTask statefulTask = createActiveStatefulTask(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0));
+        shouldThrowIfTaskNotInStateRestoring(statefulTask);
+    }
+
+    private void shouldThrowIfTaskNotInStateRestoring(final StreamTask task) {
+        when(task.state()).thenReturn(State.CREATED);
+        assertThrows(IllegalStateException.class, () -> 
stateUpdater.add(task));
+    }
+
+    @Test
+    public void shouldImmediatelyAddSingleStatelessTaskToRestoredTasks() throws Exception {
+        shouldImmediatelyAddStatelessTasksToRestoredTasks(createStatelessTaskInStateRestoring(TASK_0_0));
+    }
+
+    @Test
+    public void shouldImmediatelyAddMultipleStatelessTasksToRestoredTasks() throws Exception {
+        final StreamTask[] statelessTasks = {
+            createStatelessTaskInStateRestoring(TASK_0_0),
+            createStatelessTaskInStateRestoring(TASK_0_2),
+            createStatelessTaskInStateRestoring(TASK_1_0)
+        };
+        shouldImmediatelyAddStatelessTasksToRestoredTasks(statelessTasks);
+    }
+
+    /**
+     * Adds the given stateless tasks and verifies that all of them come back as restored active tasks
+     * in state RESTORING, without remaining in the state updater.
+     */
+    private void shouldImmediatelyAddStatelessTasksToRestoredTasks(final StreamTask... tasks) throws Exception {
+        for (final StreamTask task : tasks) {
+            stateUpdater.add(task);
+        }
+
+        final Set<StreamTask> expectedRestoredTasks = mkSet(tasks);
+        final Set<StreamTask> restoredTasks = new HashSet<>();
+        waitForCondition(
+            () -> {
+                restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT)));
+                return restoredTasks.size() == expectedRestoredTasks.size();
+            },
+            VERIFICATION_TIMEOUT,
+            // The condition waits for all expected tasks, not for any single one.
+            "Did not get all restored active tasks within the given timeout!"
+        );
+        assertTrue(restoredTasks.containsAll(expectedRestoredTasks));
+        assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(task -> task.state() == State.RESTORING).count());
+        assertTrue(stateUpdater.getAllTasks().isEmpty());
+    }
+
+    @Test
+    public void shouldRestoreSingleActiveStatefulTask() throws Exception {
+        final StreamTask task =
+            createActiveStatefulTaskInStateRestoring(TASK_0_0, Arrays.asList(TOPIC_PARTITION_A_0, TOPIC_PARTITION_B_0));
+        // The task's changelogs complete one after the other over successive calls.
+        when(changelogReader.completedChangelogs())
+            .thenReturn(Collections.emptySet())
+            .thenReturn(mkSet(TOPIC_PARTITION_A_0))
+            .thenReturn(mkSet(TOPIC_PARTITION_A_0, TOPIC_PARTITION_B_0));
+        when(changelogReader.allChangelogsCompleted())
+            .thenReturn(false)
+            .thenReturn(false)
+            .thenReturn(true);
+
+        stateUpdater.add(task);
+
+        final Set<StreamTask> expectedRestoredTasks = Collections.singleton(task);
+        final Set<StreamTask> restoredTasks = new HashSet<>();
+        waitForCondition(
+            () -> {
+                restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT)));
+                return restoredTasks.size() == expectedRestoredTasks.size();
+            },
+            VERIFICATION_TIMEOUT,
+            // The condition waits for all expected tasks, not for any single one.
+            "Did not get all restored active tasks within the given timeout!"
+        );
+        assertTrue(restoredTasks.containsAll(expectedRestoredTasks));
+        assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(t -> t.state() == State.RESTORING).count());
+        assertTrue(stateUpdater.getAllTasks().isEmpty());
+        verify(changelogReader, atLeast(3)).restore(anyMap());
+        verify(task).completeRestoration(offsetResetter);
+    }
+
+    @Test
+    public void shouldRestoreMultipleActiveStatefulTasks() throws Exception {
+        final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0));
+        final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
+        final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0));
+        // The tasks complete in the order task3, task1, task2 over successive calls.
+        when(changelogReader.completedChangelogs())
+            .thenReturn(Collections.emptySet())
+            .thenReturn(mkSet(TOPIC_PARTITION_C_0))
+            .thenReturn(mkSet(TOPIC_PARTITION_C_0, TOPIC_PARTITION_A_0))
+            .thenReturn(mkSet(TOPIC_PARTITION_C_0, TOPIC_PARTITION_A_0, TOPIC_PARTITION_B_0));
+        when(changelogReader.allChangelogsCompleted())
+            .thenReturn(false)
+            .thenReturn(false)
+            .thenReturn(false)
+            .thenReturn(true);
+
+        stateUpdater.add(task1);
+        stateUpdater.add(task2);
+        stateUpdater.add(task3);
+
+        final Set<StreamTask> expectedRestoredTasks = mkSet(task3, task1, task2);
+        final Set<StreamTask> restoredTasks = new HashSet<>();
+        waitForCondition(
+            () -> {
+                restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT)));
+                return restoredTasks.size() == expectedRestoredTasks.size();
+            },
+            VERIFICATION_TIMEOUT,
+            // The condition waits for all expected tasks, not for any single one.
+            "Did not get all restored active tasks within the given timeout!"
+        );
+        assertTrue(restoredTasks.containsAll(expectedRestoredTasks));
+        assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(t -> t.state() == State.RESTORING).count());
+        assertTrue(stateUpdater.getAllTasks().isEmpty());
+        verify(changelogReader, atLeast(4)).restore(anyMap());
+        verify(task3).completeRestoration(offsetResetter);
+        verify(task1).completeRestoration(offsetResetter);
+        verify(task2).completeRestoration(offsetResetter);
+    }
+
+    @Test
+    public void shouldAddFailedTasksToQueueWhenRestoreThrowsStreamsExceptionWithoutTask() throws Exception {
+        final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0));
+        final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
+        final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0));
+        final String expectedMessage = "The Streams were crossed!";
+        // No task id is attached, so the state updater should report all updating tasks as failed.
+        final StreamsException expectedStreamsException = new StreamsException(expectedMessage);
+        final Map<TaskId, Task> updatingTasks = mkMap(
+            mkEntry(task1.id(), task1),
+            mkEntry(task2.id(), task2),
+            mkEntry(task3.id(), task3)
+        );
+        // Only restore() calls whose argument equals all three tasks match this stubbing;
+        // the second matching call throws.
+        doNothing().doThrow(expectedStreamsException).doNothing().when(changelogReader).restore(updatingTasks);
+
+        stateUpdater.add(task1);
+        stateUpdater.add(task2);
+        stateUpdater.add(task3);
+
+        // NOTE(review): getFailedTasks(int) is defined further down; presumably it waits for the given
+        // number of failure entries — confirm.
+        final List<ExceptionAndTasks> failedTasks = getFailedTasks(1);
+        assertEquals(1, failedTasks.size());
+        final ExceptionAndTasks actualFailedTasks = failedTasks.get(0);
+        assertEquals(3, actualFailedTasks.tasks.size());
+        assertTrue(actualFailedTasks.tasks.containsAll(Arrays.asList(task1, task2, task3)));
+        assertTrue(actualFailedTasks.exception instanceof StreamsException);
+        final StreamsException actualException = (StreamsException) actualFailedTasks.exception;
+        assertFalse(actualException.taskId().isPresent());
+        assertEquals(expectedMessage, actualException.getMessage());
+        assertTrue(stateUpdater.getAllTasks().isEmpty());
+    }
+
+    @Test
+    public void 
shouldAddFailedTasksToQueueWhenRestoreThrowsStreamsExceptionWithTask() throws 
Exception {
+        final StreamTask task1 = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
+        final StreamTask task2 = 
createActiveStatefulTaskInStateRestoring(TASK_0_2, 
Collections.singletonList(TOPIC_PARTITION_B_0));
+        final StreamTask task3 = 
createActiveStatefulTaskInStateRestoring(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_C_0));
+        final String expectedMessage = "The Streams were crossed!";
+        final StreamsException expectedStreamsException1 = new 
StreamsException(expectedMessage, task1.id());
+        final StreamsException expectedStreamsException2 = new 
StreamsException(expectedMessage, task3.id());
+        final Map<TaskId, Task> updatingTasksBeforeFirstThrow = mkMap(
+            mkEntry(task1.id(), task1),
+            mkEntry(task2.id(), task2),
+            mkEntry(task3.id(), task3)
+        );
+        final Map<TaskId, Task> updatingTasksBeforeSecondThrow = mkMap(
+            mkEntry(task2.id(), task2),
+            mkEntry(task3.id(), task3)
+        );
+        doNothing()
+            .doThrow(expectedStreamsException1)
+            .when(changelogReader).restore(updatingTasksBeforeFirstThrow);
+        doNothing()
+            .doThrow(expectedStreamsException2)
+            .when(changelogReader).restore(updatingTasksBeforeSecondThrow);
+
+        stateUpdater.add(task1);
+        stateUpdater.add(task2);
+        stateUpdater.add(task3);
+
+        final List<ExceptionAndTasks> failedTasks = getFailedTasks(2);
+        assertEquals(2, failedTasks.size());
+        final ExceptionAndTasks actualFailedTasks1 = failedTasks.get(0);
+        assertEquals(1, actualFailedTasks1.tasks.size());
+        assertTrue(actualFailedTasks1.tasks.contains(task1));
+        assertTrue(actualFailedTasks1.exception instanceof StreamsException);
+        final StreamsException actualException1 = (StreamsException) 
actualFailedTasks1.exception;
+        assertTrue(actualException1.taskId().isPresent());
+        assertEquals(task1.id(), actualException1.taskId().get());
+        assertEquals(expectedMessage, actualException1.getMessage());
+        final ExceptionAndTasks actualFailedTasks2 = failedTasks.get(1);
+        assertEquals(1, actualFailedTasks2.tasks.size());
+        assertTrue(actualFailedTasks2.tasks.contains(task3));
+        assertTrue(actualFailedTasks2.exception instanceof StreamsException);
+        final StreamsException actualException2 = (StreamsException) 
actualFailedTasks2.exception;
+        assertTrue(actualException2.taskId().isPresent());
+        assertEquals(task3.id(), actualException2.taskId().get());
+        assertEquals(expectedMessage, actualException2.getMessage());
+        assertEquals(1, stateUpdater.getAllTasks().size());
+        assertTrue(stateUpdater.getAllTasks().contains(task2));
+    }
+
+    @Test
+    public void 
shouldAddFailedTasksToQueueWhenRestoreThrowsTaskCorruptedException() throws 
Exception {
+        final StreamTask task1 = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
+        final StreamTask task2 = 
createActiveStatefulTaskInStateRestoring(TASK_0_2, 
Collections.singletonList(TOPIC_PARTITION_B_0));
+        final StreamTask task3 = 
createActiveStatefulTaskInStateRestoring(TASK_1_0, 
Collections.singletonList(TOPIC_PARTITION_C_0));
+        final Set<TaskId> expectedTaskIds = mkSet(task1.id(), task2.id());
+        final TaskCorruptedException taskCorruptedException = new 
TaskCorruptedException(expectedTaskIds);
+        final Map<TaskId, Task> updatingTasks = mkMap(
+            mkEntry(task1.id(), task1),
+            mkEntry(task2.id(), task2),
+            mkEntry(task3.id(), task3)
+        );
+        
doNothing().doThrow(taskCorruptedException).doNothing().when(changelogReader).restore(updatingTasks);
+
+        stateUpdater.add(task1);
+        stateUpdater.add(task2);
+        stateUpdater.add(task3);
+
+        final List<ExceptionAndTasks> failedTasks = getFailedTasks(1);
+        assertEquals(1, failedTasks.size());
+        final List<Task> expectedFailedTasks = Arrays.asList(task1, task2);
+        final ExceptionAndTasks actualFailedTasks = failedTasks.get(0);
+        assertEquals(2, actualFailedTasks.tasks.size());
+        assertTrue(actualFailedTasks.tasks.containsAll(expectedFailedTasks));
+        assertTrue(actualFailedTasks.exception instanceof 
TaskCorruptedException);
+        final TaskCorruptedException actualException = 
(TaskCorruptedException) actualFailedTasks.exception;
+        final Set<TaskId> corruptedTasks = actualException.corruptedTasks();
+        
assertTrue(corruptedTasks.containsAll(expectedFailedTasks.stream().map(Task::id).collect(Collectors.toList())));
+        assertEquals(1, stateUpdater.getAllTasks().size());
+        assertTrue(stateUpdater.getAllTasks().contains(task3));
+    }
+
+    @Test
+    public void shouldAddFailedTasksToQueueWhenUncaughtExceptionIsThrown() 
throws Exception {
+        final StreamTask task1 = 
createActiveStatefulTaskInStateRestoring(TASK_0_0, 
Collections.singletonList(TOPIC_PARTITION_A_0));
+        final StreamTask task2 = 
createActiveStatefulTaskInStateRestoring(TASK_0_2, 
Collections.singletonList(TOPIC_PARTITION_B_0));
+        final IllegalStateException illegalStateException = new 
IllegalStateException("Nobody expects the Spanish inquisition!");
+        final Map<TaskId, Task> updatingTasks = mkMap(
+            mkEntry(task1.id(), task1),
+            mkEntry(task2.id(), task2)
+        );
+        
doThrow(illegalStateException).when(changelogReader).restore(updatingTasks);
+
+        stateUpdater.add(task1);
+        stateUpdater.add(task2);
+
+        final List<ExceptionAndTasks> failedTasks = getFailedTasks(1);
+        final List<Task> expectedFailedTasks = Arrays.asList(task1, task2);
+        final ExceptionAndTasks actualFailedTasks = failedTasks.get(0);
+        assertEquals(2, actualFailedTasks.tasks.size());
+        assertTrue(actualFailedTasks.tasks.containsAll(expectedFailedTasks));
+        assertTrue(actualFailedTasks.exception instanceof 
IllegalStateException);
+        final IllegalStateException actualException = (IllegalStateException) 
actualFailedTasks.exception;
+        assertEquals(actualException.getMessage(), 
illegalStateException.getMessage());
+        assertTrue(stateUpdater.getAllTasks().isEmpty());
+    }
+
+    private List<ExceptionAndTasks> getFailedTasks(final int expectedCount) 
throws Exception {
+        final List<ExceptionAndTasks> failedTasks = new ArrayList<>();
+        waitForCondition(
+            () -> {
+                failedTasks.addAll(stateUpdater.getFailedTasksAndExceptions());
+                return failedTasks.size() >= expectedCount;
+            },
+            VERIFICATION_TIMEOUT,
+            "Did not get enough failed tasks within the given timeout!"
+        );
+
+        return failedTasks;
+    }
+
+    private StreamTask createActiveStatefulTaskInStateRestoring(final TaskId 
taskId,
+                                                                final 
Collection<TopicPartition> changelogPartitions) {
+        final StreamTask task = createActiveStatefulTask(taskId, 
changelogPartitions);
+        when(task.state()).thenReturn(State.RESTORING);
+        return task;
+    }
+
+    private StreamTask createActiveStatefulTask(final TaskId taskId,
+                                                final 
Collection<TopicPartition> changelogPartitions) {
+        final StreamTask task = mock(StreamTask.class);
+        setupStatefulTask(task, taskId, changelogPartitions);
+        when(task.isActive()).thenReturn(true);
+        return task;
+    }
+
+    private StreamTask createStatelessTaskInStateRestoring(final TaskId 
taskId) {
+        final StreamTask task = createStatelessTask(taskId);
+        when(task.state()).thenReturn(State.RESTORING);
+        return task;
+    }
+
+    private StreamTask createStatelessTask(final TaskId taskId) {
+        final StreamTask task = mock(StreamTask.class);
+        when(task.changelogPartitions()).thenReturn(Collections.emptySet());
+        when(task.isActive()).thenReturn(true);
+        when(task.id()).thenReturn(taskId);
+        return task;
+    }
+
+    private void setupStatefulTask(final Task task,
+                                   final TaskId taskId,
+                                   final Collection<TopicPartition> 
changelogPartitions) {
+        when(task.changelogPartitions()).thenReturn(changelogPartitions);
+        when(task.id()).thenReturn(taskId);
+    }
+}
\ No newline at end of file
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
index 6ea7fc3101..d86728891c 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
@@ -59,6 +59,11 @@ public class MockChangelogReader implements ChangelogReader {
         return restoringPartitions;
     }
 
+    @Override
+    public boolean allChangelogsCompleted() {
+        return false;
+    }
+
     @Override
     public void clear() {
         restoringPartitions.clear();

Reply via email to