This is an automated email from the ASF dual-hosted git repository. guozhang pushed a commit to branch trunk in repository https://gitbox.apache.org/repos/asf/kafka.git
The following commit(s) were added to refs/heads/trunk by this push: new e67408c859 KAFKA-10199: Implement removing active and standby tasks from the state updater (#12270) e67408c859 is described below commit e67408c859fb2a80f1b3c208b7fef6ddc9a711fb Author: Bruno Cadonna <cado...@apache.org> AuthorDate: Thu Jun 9 19:28:26 2022 +0200 KAFKA-10199: Implement removing active and standby tasks from the state updater (#12270) This PR adds removing of active and standby tasks from the default implementation of the state updater. The PR also includes refactoring that cleans up the code. Reviewers: Guozhang Wang <wangg...@gmail.com> --- .../processor/internals/DefaultStateUpdater.java | 129 ++++--- .../streams/processor/internals/StateUpdater.java | 78 ++-- .../streams/processor/internals/TaskAndAction.java | 67 ++++ .../internals/DefaultStateUpdaterTest.java | 419 +++++++++++++++++---- .../processor/internals/TaskAndActionTest.java | 68 ++++ 5 files changed, 595 insertions(+), 166 deletions(-) diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java index 55935d3e21..54cb7bc427 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java @@ -23,13 +23,13 @@ import org.apache.kafka.streams.errors.StreamsException; import org.apache.kafka.streams.errors.TaskCorruptedException; import org.apache.kafka.streams.processor.TaskId; import org.apache.kafka.streams.processor.internals.Task.State; +import org.apache.kafka.streams.processor.internals.TaskAndAction.Action; import org.slf4j.Logger; import java.time.Duration; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; import 
java.util.List; @@ -37,6 +37,7 @@ import java.util.Map; import java.util.Queue; import java.util.Set; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -57,7 +58,7 @@ public class DefaultStateUpdater implements StateUpdater { private final ChangelogReader changelogReader; private final AtomicBoolean isRunning = new AtomicBoolean(true); private final Consumer<Set<TopicPartition>> offsetResetter; - private final Map<TaskId, Task> updatingTasks = new HashMap<>(); + private final Map<TaskId, Task> updatingTasks = new ConcurrentHashMap<>(); private final Logger log; public StateUpdaterThread(final String name, @@ -72,7 +73,7 @@ public class DefaultStateUpdater implements StateUpdater { log = logContext.logger(DefaultStateUpdater.class); } - public Collection<Task> getAllUpdatingTasks() { + public Collection<Task> getUpdatingTasks() { return updatingTasks.values(); } @@ -117,11 +118,13 @@ public class DefaultStateUpdater implements StateUpdater { tasksAndActionsLock.lock(); try { for (final TaskAndAction taskAndAction : getTasksAndActions()) { - final Task task = taskAndAction.task; - final Action action = taskAndAction.action; + final Action action = taskAndAction.getAction(); switch (action) { case ADD: - addTask(task); + addTask(taskAndAction.getTask()); + break; + case REMOVE: + removeTask(taskAndAction.getTaskId()); break; } } @@ -149,7 +152,7 @@ public class DefaultStateUpdater implements StateUpdater { log.error("An unexpected error occurred within the state updater thread: " + runtimeException); final ExceptionAndTasks exceptionAndTasks = new ExceptionAndTasks(new HashSet<>(updatingTasks.values()), runtimeException); updatingTasks.clear(); - failedTasks.add(exceptionAndTasks); + exceptionsAndFailedTasks.add(exceptionAndTasks); isRunning.set(false); } @@ -164,7 +167,7 @@ public class 
DefaultStateUpdater implements StateUpdater { } corruptedTasks.add(corruptedTask); } - failedTasks.add(new ExceptionAndTasks(corruptedTasks, taskCorruptedException)); + exceptionsAndFailedTasks.add(new ExceptionAndTasks(corruptedTasks, taskCorruptedException)); } private void handleStreamsException(final StreamsException streamsException) { @@ -175,7 +178,7 @@ public class DefaultStateUpdater implements StateUpdater { } else { exceptionAndTasks = handleStreamsExceptionWithoutTask(streamsException); } - failedTasks.add(exceptionAndTasks); + exceptionsAndFailedTasks.add(exceptionAndTasks); } private ExceptionAndTasks handleStreamsExceptionWithTask(final StreamsException streamsException) { @@ -230,16 +233,15 @@ public class DefaultStateUpdater implements StateUpdater { private void addTask(final Task task) { if (isStateless(task)) { - log.debug("Stateless active task " + task.id() + " was added to the state updater"); addTaskToRestoredTasks((StreamTask) task); + log.debug("Stateless active task " + task.id() + " was added to the restored tasks of the state updater"); } else { + updatingTasks.put(task.id(), task); if (task.isActive()) { - updatingTasks.put(task.id(), task); - log.debug("Stateful active task " + task.id() + " was added to the state updater"); + log.debug("Stateful active task " + task.id() + " was added to the updating tasks of the state updater"); changelogReader.enforceRestoreActive(); } else { - updatingTasks.put(task.id(), task); - log.debug("Standby task " + task.id() + " was added to the state updater"); + log.debug("Standby task " + task.id() + " was added to the updating tasks of the state updater"); if (updatingTasks.size() == 1) { changelogReader.transitToUpdateStandby(); } @@ -247,6 +249,19 @@ public class DefaultStateUpdater implements StateUpdater { } } + private void removeTask(final TaskId taskId) { + final Task task = updatingTasks.remove(taskId); + if (task != null) { + final Collection<TopicPartition> changelogPartitions = 
task.changelogPartitions(); + changelogReader.unregister(changelogPartitions); + removedTasks.add(task); + log.debug((task.isActive() ? "Active" : "Standby") + + " task " + task.id() + " was removed from the updating tasks and added to the removed tasks."); + } else { + log.debug("Task " + taskId + " was not removed since it is not updating."); + } + } + private boolean isStateless(final Task task) { return task.changelogPartitions().isEmpty() && task.isActive(); } @@ -277,20 +292,6 @@ public class DefaultStateUpdater implements StateUpdater { } } - enum Action { - ADD - } - - private static class TaskAndAction { - public final Task task; - public final Action action; - - public TaskAndAction(final Task task, final Action action) { - this.task = task; - this.action = action; - } - } - private final Time time; private final ChangelogReader changelogReader; private final Consumer<Set<TopicPartition>> offsetResetter; @@ -300,7 +301,8 @@ public class DefaultStateUpdater implements StateUpdater { private final Queue<StreamTask> restoredActiveTasks = new LinkedList<>(); private final Lock restoredActiveTasksLock = new ReentrantLock(); private final Condition restoredActiveTasksCondition = restoredActiveTasksLock.newCondition(); - private final BlockingQueue<ExceptionAndTasks> failedTasks = new LinkedBlockingQueue<>(); + private final BlockingQueue<ExceptionAndTasks> exceptionsAndFailedTasks = new LinkedBlockingQueue<>(); + private final BlockingQueue<Task> removedTasks = new LinkedBlockingQueue<>(); private CountDownLatch shutdownGate; private StateUpdaterThread stateUpdaterThread = null; @@ -325,7 +327,7 @@ public class DefaultStateUpdater implements StateUpdater { tasksAndActionsLock.lock(); try { - tasksAndActions.add(new TaskAndAction(task, Action.ADD)); + tasksAndActions.add(TaskAndAction.createAddTask(task)); tasksAndActionsCondition.signalAll(); } finally { tasksAndActionsLock.unlock(); @@ -342,11 +344,18 @@ public class DefaultStateUpdater implements StateUpdater 
{ } @Override - public void remove(final Task task) { + public void remove(final TaskId taskId) { + tasksAndActionsLock.lock(); + try { + tasksAndActions.add(TaskAndAction.createRemoveTask(taskId)); + tasksAndActionsCondition.signalAll(); + } finally { + tasksAndActionsLock.unlock(); + } } @Override - public Set<StreamTask> getRestoredActiveTasks(final Duration timeout) { + public Set<StreamTask> drainRestoredActiveTasks(final Duration timeout) { final long timeoutMs = timeout.toMillis(); final long startTime = time.milliseconds(); final long deadline = startTime + timeoutMs; @@ -375,52 +384,42 @@ public class DefaultStateUpdater implements StateUpdater { } @Override - public List<ExceptionAndTasks> getFailedTasksAndExceptions() { + public Set<Task> drainRemovedTasks() { + final List<Task> result = new ArrayList<>(); + removedTasks.drainTo(result); + return new HashSet<>(result); + } + + @Override + public List<ExceptionAndTasks> drainExceptionsAndFailedTasks() { final List<ExceptionAndTasks> result = new ArrayList<>(); - failedTasks.drainTo(result); + exceptionsAndFailedTasks.drainTo(result); return result; } - @Override - public Set<Task> getAllTasks() { - tasksAndActionsLock.lock(); + public Set<StandbyTask> getUpdatingStandbyTasks() { + return Collections.unmodifiableSet(new HashSet<>(stateUpdaterThread.getUpdatingStandbyTasks())); + } + + public Set<Task> getUpdatingTasks() { + return Collections.unmodifiableSet(new HashSet<>(stateUpdaterThread.getUpdatingTasks())); + } + + public Set<StreamTask> getRestoredActiveTasks() { restoredActiveTasksLock.lock(); try { - final Set<Task> allTasks = new HashSet<>(); - allTasks.addAll(tasksAndActions.stream() - .filter(t -> t.action == Action.ADD) - .map(t -> t.task) - .collect(Collectors.toList()) - ); - allTasks.addAll(stateUpdaterThread.getAllUpdatingTasks()); - allTasks.addAll(restoredActiveTasks); - return Collections.unmodifiableSet(allTasks); + return Collections.unmodifiableSet(new HashSet<>(restoredActiveTasks)); 
} finally { restoredActiveTasksLock.unlock(); - tasksAndActionsLock.unlock(); } } - @Override - public Set<StandbyTask> getStandbyTasks() { - tasksAndActionsLock.lock(); - try { - final Set<StandbyTask> standbyTasks = new HashSet<>(); - standbyTasks.addAll(tasksAndActions.stream() - .filter(t -> t.action == Action.ADD) - .filter(t -> !t.task.isActive()) - .map(t -> (StandbyTask) t.task) - .collect(Collectors.toList()) - ); - standbyTasks.addAll(getUpdatingStandbyTasks()); - return Collections.unmodifiableSet(standbyTasks); - } finally { - tasksAndActionsLock.unlock(); - } + public List<ExceptionAndTasks> getExceptionsAndFailedTasks() { + return Collections.unmodifiableList(new ArrayList<>(exceptionsAndFailedTasks)); } - public Set<StandbyTask> getUpdatingStandbyTasks() { - return Collections.unmodifiableSet(new HashSet<>(stateUpdaterThread.getUpdatingStandbyTasks())); + public Set<Task> getRemovedTasks() { + return Collections.unmodifiableSet(new HashSet<>(removedTasks)); } @Override diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java index d2d4ab71ad..42e65d4adb 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java @@ -16,65 +16,101 @@ */ package org.apache.kafka.streams.processor.internals; +import org.apache.kafka.streams.processor.TaskId; + import java.time.Duration; +import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.Set; public interface StateUpdater { class ExceptionAndTasks { - public final Set<Task> tasks; - public final RuntimeException exception; + private final Set<Task> tasks; + private final RuntimeException exception; public ExceptionAndTasks(final Set<Task> tasks, final RuntimeException exception) { - this.tasks = tasks; - this.exception = 
exception; + this.tasks = Objects.requireNonNull(tasks); + this.exception = Objects.requireNonNull(exception); + } + + public Set<Task> tasks() { + return Collections.unmodifiableSet(tasks); + } + + public RuntimeException exception() { + return exception; + } + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (!(o instanceof ExceptionAndTasks)) return false; + final ExceptionAndTasks that = (ExceptionAndTasks) o; + return tasks.equals(that.tasks) && exception.equals(that.exception); + } + + @Override + public int hashCode() { + return Objects.hash(tasks, exception); } } /** * Adds a task (active or standby) to the state updater. * + * This method does not block until the task is added to the state updater. + * * @param task task to add */ void add(final Task task); /** - * Removes a task (active or standby) from the state updater. + * Removes a task (active or standby) from the state updater and adds the removed task to the removed tasks. + * + * This method does not block until the removed task is removed from the state updater. * - * @param task task ro remove + * The task to be removed is not removed from the restored active tasks and the failed tasks. + * Stateless tasks will never be added to the removed tasks since they are immediately added to the + * restored active tasks. + * + * @param taskId ID of the task to remove */ - void remove(final Task task); + void remove(final TaskId taskId); /** - * Gets restored active tasks from state restoration/update + * Drains the restored active tasks from the state updater. + * + * The returned active tasks are removed from the state updater. 
* * @param timeout duration how long the calling thread should wait for restored active tasks * * @return set of active tasks with up-to-date states */ - Set<StreamTask> getRestoredActiveTasks(final Duration timeout); + Set<StreamTask> drainRestoredActiveTasks(final Duration timeout); - /** - * Gets failed tasks and the corresponding exceptions - * - * @return list of failed tasks and the corresponding exceptions - */ - List<ExceptionAndTasks> getFailedTasksAndExceptions(); /** - * Get all tasks (active and standby) that are managed by the state updater. + * Drains the removed tasks (active and standbys) from the state updater. + * + * Removed tasks returned by this method are tasks extraordinarily removed from the state updater. These do not + * include restored or failed tasks. + * + * The returned removed tasks are removed from the state updater * - * @return set of tasks managed by the state updater + * @return set of tasks removed from the state updater */ - Set<Task> getAllTasks(); + Set<Task> drainRemovedTasks(); /** - * Get standby tasks that are managed by the state updater. + * Drains the failed tasks and the corresponding exceptions. * - * @return set of standby tasks managed by the state updater + * The returned failed tasks are removed from the state updater + * + * @return list of failed tasks and the corresponding exceptions */ - Set<StandbyTask> getStandbyTasks(); + List<ExceptionAndTasks> drainExceptionsAndFailedTasks(); /** * Shuts down the state updater. diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java new file mode 100644 index 0000000000..4c4316a864 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.processor.TaskId; + +import java.util.Objects; + +public class TaskAndAction { + + enum Action { + ADD, + REMOVE + } + + private final Task task; + private final TaskId taskId; + private final Action action; + + private TaskAndAction(final Task task, final TaskId taskId, final Action action) { + this.task = task; + this.taskId = taskId; + this.action = action; + } + + public static TaskAndAction createAddTask(final Task task) { + Objects.requireNonNull(task, "Task to add is null!"); + return new TaskAndAction(task, null, Action.ADD); + } + + public static TaskAndAction createRemoveTask(final TaskId taskId) { + Objects.requireNonNull(taskId, "Task ID of task to remove is null!"); + return new TaskAndAction(null, taskId, Action.REMOVE); + } + + public Task getTask() { + if (action != Action.ADD) { + throw new IllegalStateException("Action type " + action + " cannot have a task!"); + } + return task; + } + + public TaskId getTaskId() { + if (action != Action.REMOVE) { + throw new IllegalStateException("Action type " + action + " cannot have a task ID!"); + } + return taskId; + } + + public Action getAction() { + return action; + } +} \ No newline at end of file diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java index c9fa1abede..fa50380f7f 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java @@ -36,14 +36,12 @@ import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; import static org.apache.kafka.common.utils.Utils.mkEntry; import static org.apache.kafka.common.utils.Utils.mkMap; import static org.apache.kafka.common.utils.Utils.mkSet; import static org.apache.kafka.test.TestUtils.waitForCondition; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.anyMap; @@ -154,7 +152,9 @@ class DefaultStateUpdaterTest { } verifyRestoredActiveTasks(tasks); - assertTrue(stateUpdater.getAllTasks().isEmpty()); + verifyUpdatingTasks(); + verifyExceptionsAndFailedTasks(); + verifyRemovedTasks(); } @Test @@ -173,9 +173,11 @@ class DefaultStateUpdaterTest { stateUpdater.add(task); verifyRestoredActiveTasks(task); - assertTrue(stateUpdater.getAllTasks().isEmpty()); + verifyUpdatingTasks(); + verifyExceptionsAndFailedTasks(); + verifyRemovedTasks(); verify(changelogReader, times(1)).enforceRestoreActive(); - verify(changelogReader, atLeast(1)).restore(anyMap()); + verify(changelogReader, atLeast(3)).restore(anyMap()); verify(task).completeRestoration(offsetResetter); verify(changelogReader, never()).transitToUpdateStandby(); } @@ -201,7 +203,9 @@ class DefaultStateUpdaterTest { stateUpdater.add(task3); verifyRestoredActiveTasks(task3, task1, 
task2); - assertTrue(stateUpdater.getAllTasks().isEmpty()); + verifyUpdatingTasks(); + verifyExceptionsAndFailedTasks(); + verifyRemovedTasks(); verify(changelogReader, times(3)).enforceRestoreActive(); verify(changelogReader, atLeast(4)).restore(anyMap()); verify(task3).completeRestoration(offsetResetter); @@ -210,6 +214,25 @@ class DefaultStateUpdaterTest { verify(changelogReader, never()).transitToUpdateStandby(); } + @Test + public void shouldDrainRestoredActiveTasks() throws Exception { + assertTrue(stateUpdater.drainRestoredActiveTasks(Duration.ZERO).isEmpty()); + + final StreamTask task1 = createStatelessTaskInStateRestoring(TASK_0_0); + stateUpdater.add(task1); + + verifyDrainingRestoredActiveTasks(task1); + + final StreamTask task2 = createStatelessTaskInStateRestoring(TASK_1_1); + final StreamTask task3 = createStatelessTaskInStateRestoring(TASK_1_0); + final StreamTask task4 = createStatelessTaskInStateRestoring(TASK_0_2); + stateUpdater.add(task2); + stateUpdater.add(task3); + stateUpdater.add(task4); + + verifyDrainingRestoredActiveTasks(task2, task3, task4); + } + @Test public void shouldUpdateSingleStandbyTask() throws Exception { final StandbyTask task = createStandbyTaskInStateRunning( @@ -236,6 +259,9 @@ class DefaultStateUpdaterTest { } verifyUpdatingStandbyTasks(tasks); + verifyRestoredActiveTasks(); + verifyExceptionsAndFailedTasks(); + verifyRemovedTasks(); verify(changelogReader, times(1)).transitToUpdateStandby(); verify(changelogReader, timeout(VERIFICATION_TIMEOUT).atLeast(1)).restore(anyMap()); verify(changelogReader, never()).enforceRestoreActive(); @@ -260,10 +286,12 @@ class DefaultStateUpdaterTest { stateUpdater.add(task4); verifyRestoredActiveTasks(task2, task1); + verifyUpdatingStandbyTasks(task4, task3); + verifyExceptionsAndFailedTasks(); + verifyRemovedTasks(); verify(task1).completeRestoration(offsetResetter); verify(task2).completeRestoration(offsetResetter); verify(changelogReader, atLeast(3)).restore(anyMap()); - 
verifyUpdatingStandbyTasks(task4, task3); final InOrder orderVerifier = inOrder(changelogReader, task1, task2); orderVerifier.verify(changelogReader, times(2)).enforceRestoreActive(); orderVerifier.verify(changelogReader, times(1)).transitToUpdateStandby(); @@ -293,37 +321,155 @@ class DefaultStateUpdaterTest { stateUpdater.add(task3); - verifyRestoredActiveTasks(task3); + verifyRestoredActiveTasks(task1, task3); verify(task3).completeRestoration(offsetResetter); orderVerifier.verify(changelogReader, times(1)).enforceRestoreActive(); orderVerifier.verify(changelogReader, times(1)).transitToUpdateStandby(); } + @Test + public void shouldRemoveActiveStatefulTask() throws Exception { + final StreamTask task = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); + shouldRemoveStatefulTask(task); + } + + @Test + public void shouldRemoveStandbyTask() throws Exception { + final StandbyTask task = createStandbyTaskInStateRunning(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); + shouldRemoveStatefulTask(task); + } + + private void shouldRemoveStatefulTask(final Task task) throws Exception { + when(changelogReader.completedChangelogs()) + .thenReturn(Collections.emptySet()); + when(changelogReader.allChangelogsCompleted()) + .thenReturn(false); + stateUpdater.add(task); + + stateUpdater.remove(TASK_0_0); + + verifyRemovedTasks(task); + verifyRestoredActiveTasks(); + verifyUpdatingTasks(); + verifyExceptionsAndFailedTasks(); + verify(changelogReader).unregister(Collections.singletonList(TOPIC_PARTITION_A_0)); + } + + @Test + public void shouldNotRemoveActiveStatefulTaskFromRestoredActiveTasks() throws Exception { + final StreamTask task = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); + shouldNotRemoveTaskFromRestoredActiveTasks(task); + } + + @Test + public void shouldNotRemoveStatelessTaskFromRestoredActiveTasks() throws Exception { + final StreamTask task = 
createStatelessTaskInStateRestoring(TASK_0_0); + shouldNotRemoveTaskFromRestoredActiveTasks(task); + } + + private void shouldNotRemoveTaskFromRestoredActiveTasks(final StreamTask task) throws Exception { + final StreamTask controlTask = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_B_0)); + when(changelogReader.completedChangelogs()) + .thenReturn(Collections.singleton(TOPIC_PARTITION_A_0)); + when(changelogReader.allChangelogsCompleted()) + .thenReturn(false); + stateUpdater.add(task); + stateUpdater.add(controlTask); + verifyRestoredActiveTasks(task); + + stateUpdater.remove(task.id()); + stateUpdater.remove(controlTask.id()); + + verifyRemovedTasks(controlTask); + verifyRestoredActiveTasks(task); + verifyUpdatingTasks(); + verifyExceptionsAndFailedTasks(); + } + + @Test + public void shouldNotRemoveActiveStatefulTaskFromFailedTasks() throws Exception { + final StreamTask task = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); + shouldNotRemoveTaskFromFailedTasks(task); + } + + @Test + public void shouldNotRemoveStandbyTaskFromFailedTasks() throws Exception { + final StandbyTask task = createStandbyTaskInStateRunning(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); + shouldNotRemoveTaskFromFailedTasks(task); + } + + private void shouldNotRemoveTaskFromFailedTasks(final Task task) throws Exception { + final StreamTask controlTask = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_B_0)); + final StreamsException streamsException = new StreamsException("Something happened", task.id()); + when(changelogReader.completedChangelogs()) + .thenReturn(Collections.emptySet()); + when(changelogReader.allChangelogsCompleted()) + .thenReturn(false); + doNothing() + .doThrow(streamsException) + .doNothing() + .when(changelogReader).restore(anyMap()); + stateUpdater.add(task); + stateUpdater.add(controlTask); + final 
ExceptionAndTasks expectedExceptionAndTasks = new ExceptionAndTasks(mkSet(task), streamsException); + verifyExceptionsAndFailedTasks(expectedExceptionAndTasks); + + stateUpdater.remove(task.id()); + stateUpdater.remove(controlTask.id()); + + verifyRemovedTasks(controlTask); + verifyExceptionsAndFailedTasks(expectedExceptionAndTasks); + verifyUpdatingTasks(); + verifyRestoredActiveTasks(); + } + + @Test + public void shouldDrainRemovedTasks() throws Exception { + assertTrue(stateUpdater.drainRemovedTasks().isEmpty()); + when(changelogReader.completedChangelogs()) + .thenReturn(Collections.emptySet()); + when(changelogReader.allChangelogsCompleted()) + .thenReturn(false); + + final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_B_0)); + stateUpdater.add(task1); + stateUpdater.remove(task1.id()); + + verifyDrainingRemovedTasks(task1); + + final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_1_1, Collections.singletonList(TOPIC_PARTITION_C_0)); + final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_A_0)); + final StreamTask task4 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_D_0)); + stateUpdater.add(task2); + stateUpdater.remove(task2.id()); + stateUpdater.add(task3); + stateUpdater.remove(task3.id()); + stateUpdater.add(task4); + stateUpdater.remove(task4.id()); + + verifyDrainingRemovedTasks(task2, task3, task4); + } + @Test public void shouldAddFailedTasksToQueueWhenRestoreThrowsStreamsExceptionWithoutTask() throws Exception { final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); final StandbyTask task2 = createStandbyTaskInStateRunning(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); - final String expectedMessage = "The Streams were crossed!"; - final StreamsException 
expectedStreamsException = new StreamsException(expectedMessage); + final String exceptionMessage = "The Streams were crossed!"; + final StreamsException streamsException = new StreamsException(exceptionMessage); final Map<TaskId, Task> updatingTasks = mkMap( mkEntry(task1.id(), task1), mkEntry(task2.id(), task2) ); - doNothing().doThrow(expectedStreamsException).doNothing().when(changelogReader).restore(updatingTasks); + doNothing().doThrow(streamsException).when(changelogReader).restore(updatingTasks); stateUpdater.add(task1); stateUpdater.add(task2); - final List<ExceptionAndTasks> failedTasks = getFailedTasks(1); - assertEquals(1, failedTasks.size()); - final ExceptionAndTasks actualFailedTasks = failedTasks.get(0); - assertEquals(2, actualFailedTasks.tasks.size()); - assertTrue(actualFailedTasks.tasks.containsAll(Arrays.asList(task1, task2))); - assertTrue(actualFailedTasks.exception instanceof StreamsException); - final StreamsException actualException = (StreamsException) actualFailedTasks.exception; - assertFalse(actualException.taskId().isPresent()); - assertEquals(expectedMessage, actualException.getMessage()); - assertTrue(stateUpdater.getAllTasks().isEmpty()); + final ExceptionAndTasks expectedExceptionAndTasks = new ExceptionAndTasks(mkSet(task1, task2), streamsException); + verifyExceptionsAndFailedTasks(expectedExceptionAndTasks); + verifyRemovedTasks(); + verifyUpdatingTasks(); + verifyRestoredActiveTasks(); } @Test @@ -331,9 +477,9 @@ class DefaultStateUpdaterTest { final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); final StandbyTask task3 = createStandbyTaskInStateRunning(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0)); - final String expectedMessage = "The Streams were crossed!"; - final StreamsException expectedStreamsException1 = new 
StreamsException(expectedMessage, task1.id()); - final StreamsException expectedStreamsException2 = new StreamsException(expectedMessage, task3.id()); + final String exceptionMessage = "The Streams were crossed!"; + final StreamsException streamsException1 = new StreamsException(exceptionMessage, task1.id()); + final StreamsException streamsException2 = new StreamsException(exceptionMessage, task3.id()); final Map<TaskId, Task> updatingTasksBeforeFirstThrow = mkMap( mkEntry(task1.id(), task1), mkEntry(task2.id(), task2), @@ -344,36 +490,22 @@ class DefaultStateUpdaterTest { mkEntry(task3.id(), task3) ); doNothing() - .doThrow(expectedStreamsException1) + .doThrow(streamsException1) .when(changelogReader).restore(updatingTasksBeforeFirstThrow); doNothing() - .doThrow(expectedStreamsException2) + .doThrow(streamsException2) .when(changelogReader).restore(updatingTasksBeforeSecondThrow); stateUpdater.add(task1); stateUpdater.add(task2); stateUpdater.add(task3); - final List<ExceptionAndTasks> failedTasks = getFailedTasks(2); - assertEquals(2, failedTasks.size()); - final ExceptionAndTasks actualFailedTasks1 = failedTasks.get(0); - assertEquals(1, actualFailedTasks1.tasks.size()); - assertTrue(actualFailedTasks1.tasks.contains(task1)); - assertTrue(actualFailedTasks1.exception instanceof StreamsException); - final StreamsException actualException1 = (StreamsException) actualFailedTasks1.exception; - assertTrue(actualException1.taskId().isPresent()); - assertEquals(task1.id(), actualException1.taskId().get()); - assertEquals(expectedMessage, actualException1.getMessage()); - final ExceptionAndTasks actualFailedTasks2 = failedTasks.get(1); - assertEquals(1, actualFailedTasks2.tasks.size()); - assertTrue(actualFailedTasks2.tasks.contains(task3)); - assertTrue(actualFailedTasks2.exception instanceof StreamsException); - final StreamsException actualException2 = (StreamsException) actualFailedTasks2.exception; - assertTrue(actualException2.taskId().isPresent()); - 
assertEquals(task3.id(), actualException2.taskId().get()); - assertEquals(expectedMessage, actualException2.getMessage()); - assertEquals(1, stateUpdater.getAllTasks().size()); - assertTrue(stateUpdater.getAllTasks().contains(task2)); + final ExceptionAndTasks expectedExceptionAndTasks1 = new ExceptionAndTasks(mkSet(task1), streamsException1); + final ExceptionAndTasks expectedExceptionAndTasks2 = new ExceptionAndTasks(mkSet(task3), streamsException2); + verifyExceptionsAndFailedTasks(expectedExceptionAndTasks1, expectedExceptionAndTasks2); + verifyUpdatingTasks(task2); + verifyRestoredActiveTasks(); + verifyRemovedTasks(); } @Test @@ -394,18 +526,11 @@ class DefaultStateUpdaterTest { stateUpdater.add(task2); stateUpdater.add(task3); - final List<ExceptionAndTasks> failedTasks = getFailedTasks(1); - assertEquals(1, failedTasks.size()); - final List<Task> expectedFailedTasks = Arrays.asList(task1, task2); - final ExceptionAndTasks actualFailedTasks = failedTasks.get(0); - assertEquals(2, actualFailedTasks.tasks.size()); - assertTrue(actualFailedTasks.tasks.containsAll(expectedFailedTasks)); - assertTrue(actualFailedTasks.exception instanceof TaskCorruptedException); - final TaskCorruptedException actualException = (TaskCorruptedException) actualFailedTasks.exception; - final Set<TaskId> corruptedTasks = actualException.corruptedTasks(); - assertTrue(corruptedTasks.containsAll(expectedFailedTasks.stream().map(Task::id).collect(Collectors.toList()))); - assertEquals(1, stateUpdater.getAllTasks().size()); - assertTrue(stateUpdater.getAllTasks().contains(task3)); + final ExceptionAndTasks expectedExceptionAndTasks = new ExceptionAndTasks(mkSet(task1, task2), taskCorruptedException); + verifyExceptionsAndFailedTasks(expectedExceptionAndTasks); + verifyUpdatingTasks(task3); + verifyRestoredActiveTasks(); + verifyRemovedTasks(); } @Test @@ -422,30 +547,113 @@ class DefaultStateUpdaterTest { stateUpdater.add(task1); stateUpdater.add(task2); - final List<ExceptionAndTasks> 
failedTasks = getFailedTasks(1); - final List<Task> expectedFailedTasks = Arrays.asList(task1, task2); - final ExceptionAndTasks actualFailedTasks = failedTasks.get(0); - assertEquals(2, actualFailedTasks.tasks.size()); - assertTrue(actualFailedTasks.tasks.containsAll(expectedFailedTasks)); - assertTrue(actualFailedTasks.exception instanceof IllegalStateException); - final IllegalStateException actualException = (IllegalStateException) actualFailedTasks.exception; - assertEquals(actualException.getMessage(), illegalStateException.getMessage()); - assertTrue(stateUpdater.getAllTasks().isEmpty()); + final ExceptionAndTasks expectedExceptionAndTasks = new ExceptionAndTasks(mkSet(task1, task2), illegalStateException); + verifyExceptionsAndFailedTasks(expectedExceptionAndTasks); + verifyUpdatingTasks(); + verifyRestoredActiveTasks(); + verifyRemovedTasks(); + } + + @Test + public void shouldDrainFailedTasksAndExceptions() throws Exception { + assertTrue(stateUpdater.drainExceptionsAndFailedTasks().isEmpty()); + + final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_B_0)); + final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_1_1, Collections.singletonList(TOPIC_PARTITION_C_0)); + final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_A_0)); + final StreamTask task4 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_D_0)); + final String exceptionMessage = "The Streams were crossed!"; + final StreamsException streamsException1 = new StreamsException(exceptionMessage, task1.id()); + final Map<TaskId, Task> updatingTasks1 = mkMap( + mkEntry(task1.id(), task1) + ); + doThrow(streamsException1) + .when(changelogReader).restore(updatingTasks1); + final StreamsException streamsException2 = new StreamsException(exceptionMessage, task2.id()); + final StreamsException streamsException3 = new 
StreamsException(exceptionMessage, task3.id()); + final StreamsException streamsException4 = new StreamsException(exceptionMessage, task4.id()); + final Map<TaskId, Task> updatingTasks2 = mkMap( + mkEntry(task2.id(), task2), + mkEntry(task3.id(), task3), + mkEntry(task4.id(), task4) + ); + doThrow(streamsException2).when(changelogReader).restore(updatingTasks2); + final Map<TaskId, Task> updatingTasks3 = mkMap( + mkEntry(task3.id(), task3), + mkEntry(task4.id(), task4) + ); + doThrow(streamsException3).when(changelogReader).restore(updatingTasks3); + final Map<TaskId, Task> updatingTasks4 = mkMap( + mkEntry(task4.id(), task4) + ); + doThrow(streamsException4).when(changelogReader).restore(updatingTasks4); + + stateUpdater.add(task1); + + final ExceptionAndTasks expectedExceptionAndTasks1 = new ExceptionAndTasks(mkSet(task1), streamsException1); + verifyDrainingExceptionsAndFailedTasks(expectedExceptionAndTasks1); + + stateUpdater.add(task2); + stateUpdater.add(task3); + stateUpdater.add(task4); + + final ExceptionAndTasks expectedExceptionAndTasks2 = new ExceptionAndTasks(mkSet(task2), streamsException2); + final ExceptionAndTasks expectedExceptionAndTasks3 = new ExceptionAndTasks(mkSet(task3), streamsException3); + final ExceptionAndTasks expectedExceptionAndTasks4 = new ExceptionAndTasks(mkSet(task4), streamsException4); + verifyDrainingExceptionsAndFailedTasks(expectedExceptionAndTasks2, expectedExceptionAndTasks3, expectedExceptionAndTasks4); } private void verifyRestoredActiveTasks(final StreamTask... 
tasks) throws Exception { + if (tasks.length == 0) { + assertTrue(stateUpdater.getRestoredActiveTasks().isEmpty()); + } else { + final Set<StreamTask> expectedRestoredTasks = mkSet(tasks); + final Set<StreamTask> restoredTasks = new HashSet<>(); + waitForCondition( + () -> { + restoredTasks.addAll(stateUpdater.getRestoredActiveTasks()); + return restoredTasks.containsAll(expectedRestoredTasks); + }, + VERIFICATION_TIMEOUT, + "Did not get all restored active task within the given timeout!" + ); + assertEquals(expectedRestoredTasks.size(), restoredTasks.size()); + assertTrue(restoredTasks.stream().allMatch(task -> task.state() == State.RESTORING)); + } + } + + private void verifyDrainingRestoredActiveTasks(final StreamTask... tasks) throws Exception { final Set<StreamTask> expectedRestoredTasks = mkSet(tasks); final Set<StreamTask> restoredTasks = new HashSet<>(); waitForCondition( () -> { - restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT))); - return restoredTasks.size() == expectedRestoredTasks.size(); + restoredTasks.addAll(stateUpdater.drainRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT))); + return restoredTasks.containsAll(expectedRestoredTasks); }, VERIFICATION_TIMEOUT, - "Did not get any restored active task within the given timeout!" + "Did not get all restored active task within the given timeout!" ); - assertTrue(restoredTasks.containsAll(expectedRestoredTasks)); - assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(task -> task.state() == State.RESTORING).count()); + assertEquals(expectedRestoredTasks.size(), restoredTasks.size()); + assertTrue(stateUpdater.drainRestoredActiveTasks(Duration.ZERO).isEmpty()); + } + + private void verifyUpdatingTasks(final Task... 
tasks) throws Exception { + if (tasks.length == 0) { + assertTrue(stateUpdater.getUpdatingTasks().isEmpty()); + } else { + final Set<Task> expectedUpdatingTasks = mkSet(tasks); + final Set<Task> updatingTasks = new HashSet<>(); + waitForCondition( + () -> { + updatingTasks.addAll(stateUpdater.getUpdatingTasks()); + return updatingTasks.containsAll(expectedUpdatingTasks); + }, + VERIFICATION_TIMEOUT, + "Did not get all updating task within the given timeout!" + ); + assertEquals(expectedUpdatingTasks.size(), updatingTasks.size()); + assertTrue(updatingTasks.stream().allMatch(task -> task.state() == State.RESTORING)); + } } private void verifyUpdatingStandbyTasks(final StandbyTask... tasks) throws Exception { @@ -454,27 +662,78 @@ class DefaultStateUpdaterTest { waitForCondition( () -> { standbyTasks.addAll(stateUpdater.getUpdatingStandbyTasks()); - return standbyTasks.size() == expectedStandbyTasks.size(); + return standbyTasks.containsAll(expectedStandbyTasks); }, VERIFICATION_TIMEOUT, "Did not see all standby task within the given timeout!" ); - assertTrue(standbyTasks.containsAll(expectedStandbyTasks)); - assertEquals(expectedStandbyTasks.size(), standbyTasks.stream().filter(t -> t.state() == State.RUNNING).count()); + assertEquals(expectedStandbyTasks.size(), standbyTasks.size()); + assertTrue(standbyTasks.stream().allMatch(task -> task.state() == State.RUNNING)); + } + + private void verifyRemovedTasks(final Task... tasks) throws Exception { + if (tasks.length == 0) { + assertTrue(stateUpdater.getRemovedTasks().isEmpty()); + } else { + final Set<Task> expectedRemovedTasks = mkSet(tasks); + final Set<Task> removedTasks = new HashSet<>(); + waitForCondition( + () -> { + removedTasks.addAll(stateUpdater.getRemovedTasks()); + return removedTasks.containsAll(mkSet(tasks)); + }, + VERIFICATION_TIMEOUT, + "Did not get all removed task within the given timeout!" 
+ ); + assertEquals(expectedRemovedTasks.size(), removedTasks.size()); + assertTrue(removedTasks.stream() + .allMatch(task -> task.isActive() && task.state() == State.RESTORING + || !task.isActive() && task.state() == State.RUNNING)); + } } - private List<ExceptionAndTasks> getFailedTasks(final int expectedCount) throws Exception { + private void verifyDrainingRemovedTasks(final Task... tasks) throws Exception { + final Set<Task> expectedRemovedTasks = mkSet(tasks); + final Set<Task> removedTasks = new HashSet<>(); + waitForCondition( + () -> { + removedTasks.addAll(stateUpdater.drainRemovedTasks()); + return removedTasks.containsAll(mkSet(tasks)); + }, + VERIFICATION_TIMEOUT, + "Did not get all restored active task within the given timeout!" + ); + assertEquals(expectedRemovedTasks.size(), removedTasks.size()); + assertTrue(stateUpdater.drainRemovedTasks().isEmpty()); + } + + private void verifyExceptionsAndFailedTasks(final ExceptionAndTasks... exceptionsAndTasks) throws Exception { + final List<ExceptionAndTasks> expectedExceptionAndTasks = Arrays.asList(exceptionsAndTasks); final List<ExceptionAndTasks> failedTasks = new ArrayList<>(); waitForCondition( () -> { - failedTasks.addAll(stateUpdater.getFailedTasksAndExceptions()); - return failedTasks.size() >= expectedCount; + failedTasks.addAll(stateUpdater.getExceptionsAndFailedTasks()); + return failedTasks.containsAll(expectedExceptionAndTasks); }, VERIFICATION_TIMEOUT, - "Did not get enough failed tasks within the given timeout!" + "Did not get all exceptions and failed tasks within the given timeout!" ); + assertEquals(expectedExceptionAndTasks.size(), failedTasks.size()); + } - return failedTasks; + private void verifyDrainingExceptionsAndFailedTasks(final ExceptionAndTasks... 
exceptionsAndTasks) throws Exception { + final List<ExceptionAndTasks> expectedExceptionAndTasks = Arrays.asList(exceptionsAndTasks); + final List<ExceptionAndTasks> failedTasks = new ArrayList<>(); + waitForCondition( + () -> { + failedTasks.addAll(stateUpdater.drainExceptionsAndFailedTasks()); + return failedTasks.containsAll(expectedExceptionAndTasks); + }, + VERIFICATION_TIMEOUT, + "Did not get all exceptions and failed tasks within the given timeout!" + ); + assertEquals(expectedExceptionAndTasks.size(), failedTasks.size()); + assertTrue(stateUpdater.drainExceptionsAndFailedTasks().isEmpty()); } private StreamTask createActiveStatefulTaskInStateRestoring(final TaskId taskId, diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskAndActionTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskAndActionTest.java new file mode 100644 index 0000000000..39b927ee09 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskAndActionTest.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.processor.TaskId; +import org.junit.jupiter.api.Test; + +import static org.apache.kafka.streams.processor.internals.TaskAndAction.Action.ADD; +import static org.apache.kafka.streams.processor.internals.TaskAndAction.Action.REMOVE; +import static org.apache.kafka.streams.processor.internals.TaskAndAction.createAddTask; +import static org.apache.kafka.streams.processor.internals.TaskAndAction.createRemoveTask; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +class TaskAndActionTest { + + @Test + public void shouldCreateAddTaskAction() { + final StreamTask task = mock(StreamTask.class); + + final TaskAndAction addTask = createAddTask(task); + + assertEquals(ADD, addTask.getAction()); + assertEquals(task, addTask.getTask()); + final Exception exception = assertThrows(IllegalStateException.class, addTask::getTaskId); + assertEquals("Action type ADD cannot have a task ID!", exception.getMessage()); + } + + @Test + public void shouldCreateRemoveTaskAction() { + final TaskId taskId = new TaskId(0, 0); + + final TaskAndAction removeTask = createRemoveTask(taskId); + + assertEquals(REMOVE, removeTask.getAction()); + assertEquals(taskId, removeTask.getTaskId()); + final Exception exception = assertThrows(IllegalStateException.class, removeTask::getTask); + assertEquals("Action type REMOVE cannot have a task!", exception.getMessage()); + } + + @Test + public void shouldThrowIfAddTaskActionIsCreatedWithNullTask() { + final Exception exception = assertThrows(NullPointerException.class, () -> createAddTask(null)); + assertTrue(exception.getMessage().contains("Task to add is null!")); + } + + @Test + public void shouldThrowIfRemoveTaskActionIsCreatedWithNullTaskId() { + final Exception exception = 
assertThrows(NullPointerException.class, () -> createRemoveTask(null)); + assertTrue(exception.getMessage().contains("Task ID of task to remove is null!")); + } +} \ No newline at end of file