This is an automated email from the ASF dual-hosted git repository. guozhang pushed a commit to branch trunk in repository https://gitbox.apache.org/repos/asf/kafka.git
The following commit(s) were added to refs/heads/trunk by this push: new ced5989ff6 KAFKA-10199: Implement adding active tasks to the state updater (#12128) ced5989ff6 is described below commit ced5989ff69f8a5e76518fdeb39f41ab20b2574f Author: Bruno Cadonna <cado...@apache.org> AuthorDate: Fri May 6 01:00:35 2022 +0200 KAFKA-10199: Implement adding active tasks to the state updater (#12128) This PR adds the default implementation of the state updater. The implementation only implements adding active tasks to the state updater. Reviewers: Guozhang Wang <wangg...@gmail.com> --- .../processor/internals/ChangelogReader.java | 2 + .../processor/internals/DefaultStateUpdater.java | 373 +++++++++++++++++++ .../streams/processor/internals/StateUpdater.java | 19 +- .../processor/internals/StoreChangelogReader.java | 3 +- .../internals/DefaultStateUpdaterTest.java | 408 +++++++++++++++++++++ .../processor/internals/MockChangelogReader.java | 5 + 6 files changed, 804 insertions(+), 6 deletions(-) diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java index 9c62dd182e..38b00232c8 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java @@ -46,6 +46,8 @@ public interface ChangelogReader extends ChangelogRegister { */ Set<TopicPartition> completedChangelogs(); + boolean allChangelogsCompleted(); + /** * Clear all partitions */ diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java new file mode 100644 index 0000000000..0b6558d8ac --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java @@ -0,0 +1,373 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskCorruptedException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.Task.State; +import org.slf4j.Logger; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import java.util.stream.Collectors; + +public class DefaultStateUpdater implements StateUpdater { + + private 
final static String BUG_ERROR_MESSAGE = "This indicates a bug. " + + "Please report at https://issues.apache.org/jira/projects/KAFKA/issues or to the dev-mailing list (https://kafka.apache.org/contact)."; + + private class StateUpdaterThread extends Thread { + + private final ChangelogReader changelogReader; + private final AtomicBoolean isRunning = new AtomicBoolean(true); + private final java.util.function.Consumer<Set<TopicPartition>> offsetResetter; + private final Map<TaskId, Task> updatingTasks = new HashMap<>(); + private final Logger log; + + public StateUpdaterThread(final String name, + final ChangelogReader changelogReader, + final java.util.function.Consumer<Set<TopicPartition>> offsetResetter) { + super(name); + this.changelogReader = changelogReader; + this.offsetResetter = offsetResetter; + + final String logPrefix = String.format("%s ", name); + final LogContext logContext = new LogContext(logPrefix); + log = logContext.logger(DefaultStateUpdater.class); + } + + public Collection<Task> getAllUpdatingTasks() { + return updatingTasks.values(); + } + + @Override + public void run() { + try { + while (isRunning.get()) { + try { + performActionsOnTasks(); + restoreTasks(); + waitIfAllChangelogsCompletelyRead(); + } catch (final InterruptedException interruptedException) { + return; + } + } + } catch (final RuntimeException anyOtherException) { + log.error("An unexpected error occurred within the state updater thread: " + anyOtherException); + final ExceptionAndTasks exceptionAndTasks = new ExceptionAndTasks(new HashSet<>(updatingTasks.values()), anyOtherException); + updatingTasks.clear(); + failedTasks.add(exceptionAndTasks); + isRunning.set(false); + } finally { + clear(); + } + } + + private void performActionsOnTasks() throws InterruptedException { + tasksAndActionsLock.lock(); + try { + for (final TaskAndAction taskAndAction : getTasksAndActions()) { + final Task task = taskAndAction.task; + final Action action = taskAndAction.action; + switch 
(action) { + case ADD: + addTask(task); + break; + } + } + } finally { + tasksAndActionsLock.unlock(); + } + } + + private void restoreTasks() throws InterruptedException { + try { + // ToDo: Prioritize restoration of active tasks over standby tasks + // changelogReader.enforceRestoreActive(); + changelogReader.restore(updatingTasks); + } catch (final TaskCorruptedException taskCorruptedException) { + handleTaskCorruptedException(taskCorruptedException); + } catch (final StreamsException streamsException) { + handleStreamsException(streamsException); + } + final Set<TopicPartition> completedChangelogs = changelogReader.completedChangelogs(); + final List<Task> activeTasks = updatingTasks.values().stream().filter(Task::isActive).collect(Collectors.toList()); + for (final Task task : activeTasks) { + endRestorationIfChangelogsCompletelyRead(task, completedChangelogs); + } + } + + private void handleTaskCorruptedException(final TaskCorruptedException taskCorruptedException) { + final Set<TaskId> corruptedTaskIds = taskCorruptedException.corruptedTasks(); + final Set<Task> corruptedTasks = new HashSet<>(); + for (final TaskId taskId : corruptedTaskIds) { + final Task corruptedTask = updatingTasks.remove(taskId); + if (corruptedTask == null) { + throw new IllegalStateException("Task " + taskId + " is corrupted but is not updating. 
" + BUG_ERROR_MESSAGE); + } + corruptedTasks.add(corruptedTask); + } + failedTasks.add(new ExceptionAndTasks(corruptedTasks, taskCorruptedException)); + } + + private void handleStreamsException(final StreamsException streamsException) { + final ExceptionAndTasks exceptionAndTasks; + if (streamsException.taskId().isPresent()) { + exceptionAndTasks = handleStreamsExceptionWithTask(streamsException); + } else { + exceptionAndTasks = handleStreamsExceptionWithoutTask(streamsException); + } + failedTasks.add(exceptionAndTasks); + } + + private ExceptionAndTasks handleStreamsExceptionWithTask(final StreamsException streamsException) { + final TaskId failedTaskId = streamsException.taskId().get(); + if (!updatingTasks.containsKey(failedTaskId)) { + throw new IllegalStateException("Task " + failedTaskId + " failed but is not updating. " + BUG_ERROR_MESSAGE); + } + final Set<Task> failedTask = new HashSet<>(); + failedTask.add(updatingTasks.get(failedTaskId)); + updatingTasks.remove(failedTaskId); + return new ExceptionAndTasks(failedTask, streamsException); + } + + private ExceptionAndTasks handleStreamsExceptionWithoutTask(final StreamsException streamsException) { + final ExceptionAndTasks exceptionAndTasks = new ExceptionAndTasks(new HashSet<>(updatingTasks.values()), streamsException); + updatingTasks.clear(); + return exceptionAndTasks; + } + + private void waitIfAllChangelogsCompletelyRead() throws InterruptedException { + if (isRunning.get() && changelogReader.allChangelogsCompleted()) { + tasksAndActionsLock.lock(); + try { + while (tasksAndActions.isEmpty()) { + tasksAndActionsCondition.await(); + } + } finally { + tasksAndActionsLock.unlock(); + } + } + } + + private void clear() { + tasksAndActionsLock.lock(); + restoredActiveTasksLock.lock(); + try { + tasksAndActions.clear(); + restoredActiveTasks.clear(); + } finally { + tasksAndActionsLock.unlock(); + restoredActiveTasksLock.unlock(); + } + changelogReader.clear(); + updatingTasks.clear(); + } + + private 
List<TaskAndAction> getTasksAndActions() { + final List<TaskAndAction> tasksAndActionsToProcess = new ArrayList<>(tasksAndActions); + tasksAndActions.clear(); + return tasksAndActionsToProcess; + } + + private void addTask(final Task task) { + if (isStateless(task)) { + addTaskToRestoredTasks((StreamTask) task); + } else { + updatingTasks.put(task.id(), task); + } + } + + private boolean isStateless(final Task task) { + return task.changelogPartitions().isEmpty() && task.isActive(); + } + + private void endRestorationIfChangelogsCompletelyRead(final Task task, + final Set<TopicPartition> restoredChangelogs) { + final Collection<TopicPartition> taskChangelogPartitions = task.changelogPartitions(); + if (restoredChangelogs.containsAll(taskChangelogPartitions)) { + task.completeRestoration(offsetResetter); + addTaskToRestoredTasks((StreamTask) task); + updatingTasks.remove(task.id()); + } + } + + private void addTaskToRestoredTasks(final StreamTask task) { + restoredActiveTasksLock.lock(); + try { + restoredActiveTasks.add(task); + restoredActiveTasksCondition.signalAll(); + } finally { + restoredActiveTasksLock.unlock(); + } + } + } + + enum Action { + ADD + } + + private static class TaskAndAction { + public final Task task; + public final Action action; + + public TaskAndAction(final Task task, final Action action) { + this.task = task; + this.action = action; + } + } + + private final Time time; + private final ChangelogReader changelogReader; + private final java.util.function.Consumer<Set<TopicPartition>> offsetResetter; + private final Queue<TaskAndAction> tasksAndActions = new LinkedList<>(); + private final Lock tasksAndActionsLock = new ReentrantLock(); + private final Condition tasksAndActionsCondition = tasksAndActionsLock.newCondition(); + private final Queue<StreamTask> restoredActiveTasks = new LinkedList<>(); + private final Lock restoredActiveTasksLock = new ReentrantLock(); + private final Condition restoredActiveTasksCondition = 
restoredActiveTasksLock.newCondition(); + private final BlockingQueue<ExceptionAndTasks> failedTasks = new LinkedBlockingQueue<>(); + + private StateUpdaterThread stateUpdaterThread = null; + + public DefaultStateUpdater(final ChangelogReader changelogReader, + final java.util.function.Consumer<Set<TopicPartition>> offsetResetter, + final Time time) { + this.changelogReader = changelogReader; + this.offsetResetter = offsetResetter; + this.time = time; + } + + @Override + public void add(final Task task) { + if (stateUpdaterThread == null) { + stateUpdaterThread = new StateUpdaterThread("state-updater", changelogReader, offsetResetter); + stateUpdaterThread.start(); + } + + verifyStateFor(task); + + tasksAndActionsLock.lock(); + try { + tasksAndActions.add(new TaskAndAction(task, Action.ADD)); + tasksAndActionsCondition.signalAll(); + } finally { + tasksAndActionsLock.unlock(); + } + } + + private void verifyStateFor(final Task task) { + if (task.isActive() && task.state() != State.RESTORING) { + throw new IllegalStateException("Active task " + task.id() + " is not in state RESTORING. 
" + BUG_ERROR_MESSAGE); + } + } + + @Override + public void remove(final Task task) { + } + + @Override + public Set<StreamTask> getRestoredActiveTasks(final Duration timeout) { + final long timeoutMs = timeout.toMillis(); + final long startTime = time.milliseconds(); + final long deadline = startTime + timeoutMs; + long now = startTime; + final Set<StreamTask> result = new HashSet<>(); + try { + while (now <= deadline && result.isEmpty()) { + restoredActiveTasksLock.lock(); + try { + while (restoredActiveTasks.isEmpty() && now <= deadline) { + final boolean elapsed = restoredActiveTasksCondition.await(deadline - now, TimeUnit.MILLISECONDS); + now = time.milliseconds(); + } + while (!restoredActiveTasks.isEmpty()) { + result.add(restoredActiveTasks.poll()); + } + } finally { + restoredActiveTasksLock.unlock(); + } + now = time.milliseconds(); + } + return result; + } catch (final InterruptedException e) { + // ignore + } + return result; + } + + @Override + public List<ExceptionAndTasks> getFailedTasksAndExceptions() { + final List<ExceptionAndTasks> result = new ArrayList<>(); + failedTasks.drainTo(result); + return result; + } + + @Override + public Set<Task> getAllTasks() { + tasksAndActionsLock.lock(); + restoredActiveTasksLock.lock(); + try { + final Set<Task> allTasks = new HashSet<>(); + allTasks.addAll(tasksAndActions.stream() + .filter(t -> t.action == Action.ADD) + .map(t -> t.task) + .collect(Collectors.toList()) + ); + allTasks.addAll(stateUpdaterThread.getAllUpdatingTasks()); + allTasks.addAll(restoredActiveTasks); + return Collections.unmodifiableSet(allTasks); + } finally { + tasksAndActionsLock.unlock(); + restoredActiveTasksLock.unlock(); + } + } + + @Override + public void shutdown(final Duration timeout) { + if (stateUpdaterThread != null) { + stateUpdaterThread.isRunning.set(false); + stateUpdaterThread.interrupt(); + try { + stateUpdaterThread.join(timeout.toMillis()); + stateUpdaterThread = null; + } catch (final InterruptedException e) { + // 
ignore + } + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java index 8965abfbe9..9e98e0d2c9 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java @@ -22,6 +22,16 @@ import java.util.Set; public interface StateUpdater { + class ExceptionAndTasks { + public final Set<Task> tasks; + public final RuntimeException exception; + + public ExceptionAndTasks(final Set<Task> tasks, final RuntimeException exception) { + this.tasks = tasks; + this.exception = exception; + } + } + /** * Adds a task (active or standby) to the state updater. * @@ -41,17 +51,16 @@ public interface StateUpdater { * * @param timeout duration how long the calling thread should wait for restored active tasks * - * @return list of active tasks with up-to-date states + * @return set of active tasks with up-to-date states */ Set<StreamTask> getRestoredActiveTasks(final Duration timeout); /** - * Gets a list of exceptions thrown during restoration. + * Gets failed tasks and the corresponding exceptions * - * @return exceptions + * @return list of failed tasks and the corresponding exceptions */ - List<RuntimeException> getExceptions(); - + List<ExceptionAndTasks> getFailedTasksAndExceptions(); /** * Get all tasks (active and standby) that are managed by the state updater. 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java index fdf027f2be..756bf11b0a 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java @@ -394,7 +394,8 @@ public class StoreChangelogReader implements ChangelogReader { .collect(Collectors.toSet()); } - private boolean allChangelogsCompleted() { + @Override + public boolean allChangelogsCompleted() { return changelogs.values().stream() .allMatch(metadata -> metadata.changelogState == ChangelogState.COMPLETED); } diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java new file mode 100644 index 0000000000..e94d8b1488 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java @@ -0,0 +1,408 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskCorruptedException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.StateUpdater.ExceptionAndTasks; +import org.apache.kafka.streams.processor.internals.Task.State; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.test.TestUtils.waitForCondition; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class DefaultStateUpdaterTest { + + private final static long CALL_TIMEOUT = 1000; + private final static long VERIFICATION_TIMEOUT = 15000; + private final static TopicPartition TOPIC_PARTITION_A_0 = new TopicPartition("topicA", 0); + private final static TopicPartition TOPIC_PARTITION_B_0 = new 
TopicPartition("topicB", 0); + private final static TopicPartition TOPIC_PARTITION_C_0 = new TopicPartition("topicC", 0); + private final static TaskId TASK_0_0 = new TaskId(0, 0); + private final static TaskId TASK_0_2 = new TaskId(0, 2); + private final static TaskId TASK_1_0 = new TaskId(1, 0); + + private final ChangelogReader changelogReader = mock(ChangelogReader.class); + private final java.util.function.Consumer<Set<TopicPartition>> offsetResetter = topicPartitions -> { }; + private final DefaultStateUpdater stateUpdater = new DefaultStateUpdater(changelogReader, offsetResetter, Time.SYSTEM); + + @AfterEach + public void tearDown() { + stateUpdater.shutdown(Duration.ofMinutes(1)); + } + + @Test + public void shouldShutdownStateUpdater() { + final StreamTask task = createStatelessTaskInStateRestoring(TASK_0_0); + stateUpdater.add(task); + + stateUpdater.shutdown(Duration.ofMinutes(1)); + + verify(changelogReader).clear(); + } + + @Test + public void shouldShutdownStateUpdaterAndRestart() { + final StreamTask task1 = createStatelessTaskInStateRestoring(TASK_0_0); + stateUpdater.add(task1); + + stateUpdater.shutdown(Duration.ofMinutes(1)); + + final StreamTask task2 = createStatelessTaskInStateRestoring(TASK_1_0); + stateUpdater.add(task2); + + stateUpdater.shutdown(Duration.ofMinutes(1)); + + verify(changelogReader, times(2)).clear(); + } + + @Test + public void shouldThrowIfStatelessTaskNotInStateRestoring() { + shouldThrowIfTaskNotInStateRestoring(createStatelessTask(TASK_0_0)); + } + + @Test + public void shouldThrowIfStatefulTaskNotInStateRestoring() { + shouldThrowIfTaskNotInStateRestoring(createActiveStatefulTask(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0))); + } + + private void shouldThrowIfTaskNotInStateRestoring(final StreamTask task) { + when(task.state()).thenReturn(State.CREATED); + assertThrows(IllegalStateException.class, () -> stateUpdater.add(task)); + } + + @Test + public void 
shouldImmediatelyAddSingleStatelessTaskToRestoredTasks() throws Exception { + final StreamTask task1 = createStatelessTaskInStateRestoring(TASK_0_0); + shouldImmediatelyAddStatelessTasksToRestoredTasks(task1); + } + + @Test + public void shouldImmediatelyAddMultipleStatelessTasksToRestoredTasks() throws Exception { + final StreamTask task1 = createStatelessTaskInStateRestoring(TASK_0_0); + final StreamTask task2 = createStatelessTaskInStateRestoring(TASK_0_2); + final StreamTask task3 = createStatelessTaskInStateRestoring(TASK_1_0); + shouldImmediatelyAddStatelessTasksToRestoredTasks(task1, task2, task3); + } + + private void shouldImmediatelyAddStatelessTasksToRestoredTasks(final StreamTask... tasks) throws Exception { + for (final StreamTask task : tasks) { + stateUpdater.add(task); + } + + final Set<StreamTask> expectedRestoredTasks = mkSet(tasks); + final Set<StreamTask> restoredTasks = new HashSet<>(); + waitForCondition( + () -> { + restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT))); + return restoredTasks.size() == expectedRestoredTasks.size(); + }, + VERIFICATION_TIMEOUT, + "Did not get any restored active task within the given timeout!" 
+ ); + assertTrue(restoredTasks.containsAll(expectedRestoredTasks)); + assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(task -> task.state() == State.RESTORING).count()); + assertTrue(stateUpdater.getAllTasks().isEmpty()); + } + + @Test + public void shouldRestoreSingleActiveStatefulTask() throws Exception { + final StreamTask task = + createActiveStatefulTaskInStateRestoring(TASK_0_0, Arrays.asList(TOPIC_PARTITION_A_0, TOPIC_PARTITION_B_0)); + when(changelogReader.completedChangelogs()) + .thenReturn(Collections.emptySet()) + .thenReturn(mkSet(TOPIC_PARTITION_A_0)) + .thenReturn(mkSet(TOPIC_PARTITION_A_0, TOPIC_PARTITION_B_0)); + when(changelogReader.allChangelogsCompleted()) + .thenReturn(false) + .thenReturn(false) + .thenReturn(true); + + stateUpdater.add(task); + + final Set<StreamTask> expectedRestoredTasks = Collections.singleton(task); + final Set<StreamTask> restoredTasks = new HashSet<>(); + waitForCondition( + () -> { + restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT))); + return restoredTasks.size() == expectedRestoredTasks.size(); + }, + VERIFICATION_TIMEOUT, + "Did not get any restored active task within the given timeout!" 
+ ); + assertTrue(restoredTasks.containsAll(expectedRestoredTasks)); + assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(t -> t.state() == State.RESTORING).count()); + assertTrue(stateUpdater.getAllTasks().isEmpty()); + verify(changelogReader, atLeast(3)).restore(anyMap()); + verify(task).completeRestoration(offsetResetter); + } + + @Test + public void shouldRestoreMultipleActiveStatefulTasks() throws Exception { + final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); + final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); + final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0)); + when(changelogReader.completedChangelogs()) + .thenReturn(Collections.emptySet()) + .thenReturn(mkSet(TOPIC_PARTITION_C_0)) + .thenReturn(mkSet(TOPIC_PARTITION_C_0, TOPIC_PARTITION_A_0)) + .thenReturn(mkSet(TOPIC_PARTITION_C_0, TOPIC_PARTITION_A_0, TOPIC_PARTITION_B_0)); + when(changelogReader.allChangelogsCompleted()) + .thenReturn(false) + .thenReturn(false) + .thenReturn(false) + .thenReturn(true); + + stateUpdater.add(task1); + stateUpdater.add(task2); + stateUpdater.add(task3); + + final Set<StreamTask> expectedRestoredTasks = mkSet(task3, task1, task2); + final Set<StreamTask> restoredTasks = new HashSet<>(); + waitForCondition( + () -> { + restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT))); + return restoredTasks.size() == expectedRestoredTasks.size(); + }, + VERIFICATION_TIMEOUT, + "Did not get any restored active task within the given timeout!" 
+ ); + assertTrue(restoredTasks.containsAll(expectedRestoredTasks)); + assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(t -> t.state() == State.RESTORING).count()); + assertTrue(stateUpdater.getAllTasks().isEmpty()); + verify(changelogReader, atLeast(4)).restore(anyMap()); + verify(task3).completeRestoration(offsetResetter); + verify(task1).completeRestoration(offsetResetter); + verify(task2).completeRestoration(offsetResetter); + } + + @Test + public void shouldAddFailedTasksToQueueWhenRestoreThrowsStreamsExceptionWithoutTask() throws Exception { + final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); + final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); + final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0)); + final String expectedMessage = "The Streams were crossed!"; + final StreamsException expectedStreamsException = new StreamsException(expectedMessage); + final Map<TaskId, Task> updatingTasks = mkMap( + mkEntry(task1.id(), task1), + mkEntry(task2.id(), task2), + mkEntry(task3.id(), task3) + ); + doNothing().doThrow(expectedStreamsException).doNothing().when(changelogReader).restore(updatingTasks); + + stateUpdater.add(task1); + stateUpdater.add(task2); + stateUpdater.add(task3); + + final List<ExceptionAndTasks> failedTasks = getFailedTasks(1); + assertEquals(1, failedTasks.size()); + final ExceptionAndTasks actualFailedTasks = failedTasks.get(0); + assertEquals(3, actualFailedTasks.tasks.size()); + assertTrue(actualFailedTasks.tasks.containsAll(Arrays.asList(task1, task2, task3))); + assertTrue(actualFailedTasks.exception instanceof StreamsException); + final StreamsException actualException = (StreamsException) actualFailedTasks.exception; + assertFalse(actualException.taskId().isPresent()); + 
assertEquals(expectedMessage, actualException.getMessage()); + assertTrue(stateUpdater.getAllTasks().isEmpty()); + } + + @Test + public void shouldAddFailedTasksToQueueWhenRestoreThrowsStreamsExceptionWithTask() throws Exception { + final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); + final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); + final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0)); + final String expectedMessage = "The Streams were crossed!"; + final StreamsException expectedStreamsException1 = new StreamsException(expectedMessage, task1.id()); + final StreamsException expectedStreamsException2 = new StreamsException(expectedMessage, task3.id()); + final Map<TaskId, Task> updatingTasksBeforeFirstThrow = mkMap( + mkEntry(task1.id(), task1), + mkEntry(task2.id(), task2), + mkEntry(task3.id(), task3) + ); + final Map<TaskId, Task> updatingTasksBeforeSecondThrow = mkMap( + mkEntry(task2.id(), task2), + mkEntry(task3.id(), task3) + ); + doNothing() + .doThrow(expectedStreamsException1) + .when(changelogReader).restore(updatingTasksBeforeFirstThrow); + doNothing() + .doThrow(expectedStreamsException2) + .when(changelogReader).restore(updatingTasksBeforeSecondThrow); + + stateUpdater.add(task1); + stateUpdater.add(task2); + stateUpdater.add(task3); + + final List<ExceptionAndTasks> failedTasks = getFailedTasks(2); + assertEquals(2, failedTasks.size()); + final ExceptionAndTasks actualFailedTasks1 = failedTasks.get(0); + assertEquals(1, actualFailedTasks1.tasks.size()); + assertTrue(actualFailedTasks1.tasks.contains(task1)); + assertTrue(actualFailedTasks1.exception instanceof StreamsException); + final StreamsException actualException1 = (StreamsException) actualFailedTasks1.exception; + assertTrue(actualException1.taskId().isPresent()); + 
assertEquals(task1.id(), actualException1.taskId().get()); + assertEquals(expectedMessage, actualException1.getMessage()); + final ExceptionAndTasks actualFailedTasks2 = failedTasks.get(1); + assertEquals(1, actualFailedTasks2.tasks.size()); + assertTrue(actualFailedTasks2.tasks.contains(task3)); + assertTrue(actualFailedTasks2.exception instanceof StreamsException); + final StreamsException actualException2 = (StreamsException) actualFailedTasks2.exception; + assertTrue(actualException2.taskId().isPresent()); + assertEquals(task3.id(), actualException2.taskId().get()); + assertEquals(expectedMessage, actualException2.getMessage()); + assertEquals(1, stateUpdater.getAllTasks().size()); + assertTrue(stateUpdater.getAllTasks().contains(task2)); + } + + @Test + public void shouldAddFailedTasksToQueueWhenRestoreThrowsTaskCorruptedException() throws Exception { + final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); + final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); + final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0)); + final Set<TaskId> expectedTaskIds = mkSet(task1.id(), task2.id()); + final TaskCorruptedException taskCorruptedException = new TaskCorruptedException(expectedTaskIds); + final Map<TaskId, Task> updatingTasks = mkMap( + mkEntry(task1.id(), task1), + mkEntry(task2.id(), task2), + mkEntry(task3.id(), task3) + ); + doNothing().doThrow(taskCorruptedException).doNothing().when(changelogReader).restore(updatingTasks); + + stateUpdater.add(task1); + stateUpdater.add(task2); + stateUpdater.add(task3); + + final List<ExceptionAndTasks> failedTasks = getFailedTasks(1); + assertEquals(1, failedTasks.size()); + final List<Task> expectedFailedTasks = Arrays.asList(task1, task2); + final ExceptionAndTasks actualFailedTasks = failedTasks.get(0); + 
assertEquals(2, actualFailedTasks.tasks.size()); + assertTrue(actualFailedTasks.tasks.containsAll(expectedFailedTasks)); + assertTrue(actualFailedTasks.exception instanceof TaskCorruptedException); + final TaskCorruptedException actualException = (TaskCorruptedException) actualFailedTasks.exception; + final Set<TaskId> corruptedTasks = actualException.corruptedTasks(); + assertTrue(corruptedTasks.containsAll(expectedFailedTasks.stream().map(Task::id).collect(Collectors.toList()))); + assertEquals(1, stateUpdater.getAllTasks().size()); + assertTrue(stateUpdater.getAllTasks().contains(task3)); + } + + @Test + public void shouldAddFailedTasksToQueueWhenUncaughtExceptionIsThrown() throws Exception { + final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); + final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); + final IllegalStateException illegalStateException = new IllegalStateException("Nobody expects the Spanish inquisition!"); + final Map<TaskId, Task> updatingTasks = mkMap( + mkEntry(task1.id(), task1), + mkEntry(task2.id(), task2) + ); + doThrow(illegalStateException).when(changelogReader).restore(updatingTasks); + + stateUpdater.add(task1); + stateUpdater.add(task2); + + final List<ExceptionAndTasks> failedTasks = getFailedTasks(1); + final List<Task> expectedFailedTasks = Arrays.asList(task1, task2); + final ExceptionAndTasks actualFailedTasks = failedTasks.get(0); + assertEquals(2, actualFailedTasks.tasks.size()); + assertTrue(actualFailedTasks.tasks.containsAll(expectedFailedTasks)); + assertTrue(actualFailedTasks.exception instanceof IllegalStateException); + final IllegalStateException actualException = (IllegalStateException) actualFailedTasks.exception; + assertEquals(actualException.getMessage(), illegalStateException.getMessage()); + assertTrue(stateUpdater.getAllTasks().isEmpty()); + } + + private 
List<ExceptionAndTasks> getFailedTasks(final int expectedCount) throws Exception { + final List<ExceptionAndTasks> failedTasks = new ArrayList<>(); + waitForCondition( + () -> { + failedTasks.addAll(stateUpdater.getFailedTasksAndExceptions()); + return failedTasks.size() >= expectedCount; + }, + VERIFICATION_TIMEOUT, + "Did not get enough failed tasks within the given timeout!" + ); + + return failedTasks; + } + + private StreamTask createActiveStatefulTaskInStateRestoring(final TaskId taskId, + final Collection<TopicPartition> changelogPartitions) { + final StreamTask task = createActiveStatefulTask(taskId, changelogPartitions); + when(task.state()).thenReturn(State.RESTORING); + return task; + } + + private StreamTask createActiveStatefulTask(final TaskId taskId, + final Collection<TopicPartition> changelogPartitions) { + final StreamTask task = mock(StreamTask.class); + setupStatefulTask(task, taskId, changelogPartitions); + when(task.isActive()).thenReturn(true); + return task; + } + + private StreamTask createStatelessTaskInStateRestoring(final TaskId taskId) { + final StreamTask task = createStatelessTask(taskId); + when(task.state()).thenReturn(State.RESTORING); + return task; + } + + private StreamTask createStatelessTask(final TaskId taskId) { + final StreamTask task = mock(StreamTask.class); + when(task.changelogPartitions()).thenReturn(Collections.emptySet()); + when(task.isActive()).thenReturn(true); + when(task.id()).thenReturn(taskId); + return task; + } + + private void setupStatefulTask(final Task task, + final TaskId taskId, + final Collection<TopicPartition> changelogPartitions) { + when(task.changelogPartitions()).thenReturn(changelogPartitions); + when(task.id()).thenReturn(taskId); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java index 6ea7fc3101..d86728891c 
100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
@@ -59,6 +59,11 @@ public class MockChangelogReader implements ChangelogReader {
         return restoringPartitions;
     }
 
+    @Override // implements the allChangelogsCompleted() method this commit adds to ChangelogReader
+    public boolean allChangelogsCompleted() {
+        return false; // constant answer: this mock never reports that all changelogs are completed
+    }
+
     @Override
     public void clear() {
         restoringPartitions.clear();